diff --git a/.gitignore b/.gitignore
index 34939e3a97aaa..30b1e12bf1b03 100644
--- a/.gitignore
+++ b/.gitignore
@@ -5,6 +5,7 @@
*.ipr
*.iml
*.iws
+*.pyc
.idea/
.idea_modules/
sbt/*.jar
@@ -49,7 +50,9 @@ dependency-reduced-pom.xml
checkpoint
derby.log
dist/
-spark-*-bin.tar.gz
+dev/create-release/*txt
+dev/create-release/*final
+spark-*-bin-*.tgz
unit-tests.log
/lib/
rat-results.txt
diff --git a/.rat-excludes b/.rat-excludes
index 20e3372464386..769defbac11b7 100644
--- a/.rat-excludes
+++ b/.rat-excludes
@@ -44,6 +44,7 @@ SparkImports.scala
SparkJLineCompletion.scala
SparkJLineReader.scala
SparkMemberHandlers.scala
+SparkReplReporter.scala
sbt
sbt-launch-lib.bash
plugins.sbt
@@ -63,3 +64,4 @@ dist/*
logs
.*scalastyle-output.xml
.*dependency-reduced-pom.xml
+known_translations
diff --git a/LICENSE b/LICENSE
index f1732fb47afc0..0a42d389e4c3c 100644
--- a/LICENSE
+++ b/LICENSE
@@ -646,7 +646,8 @@ THE SOFTWARE.
========================================================================
For Scala Interpreter classes (all .scala files in repl/src/main/scala
-except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala):
+except for Main.Scala, SparkHelper.scala and ExecutorClassLoader.scala),
+and for SerializableMapWrapper in JavaUtils.scala:
========================================================================
Copyright (c) 2002-2013 EPFL
@@ -754,7 +755,7 @@ SUCH DAMAGE.
========================================================================
-For Timsort (core/src/main/java/org/apache/spark/util/collection/Sorter.java):
+For Timsort (core/src/main/java/org/apache/spark/util/collection/TimSort.java):
========================================================================
Copyright (C) 2008 The Android Open Source Project
@@ -771,6 +772,25 @@ See the License for the specific language governing permissions and
limitations under the License.
+========================================================================
+For LimitedInputStream
+ (network/common/src/main/java/org/apache/spark/network/util/LimitedInputStream.java):
+========================================================================
+Copyright (C) 2007 The Guava Authors
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+
+
========================================================================
BSD-style licenses
========================================================================
diff --git a/README.md b/README.md
index 9916ac7b1ae8e..af02339578195 100644
--- a/README.md
+++ b/README.md
@@ -13,7 +13,8 @@ and Spark Streaming for stream processing.
## Online Documentation
You can find the latest Spark documentation, including a programming
-guide, on the [project web page](http://spark.apache.org/documentation.html).
+guide, on the [project web page](http://spark.apache.org/documentation.html)
+and [project wiki](https://cwiki.apache.org/confluence/display/SPARK).
This README file only contains basic setup instructions.
## Building Spark
@@ -25,7 +26,7 @@ To build Spark and its example programs, run:
(You do not need to do this if you downloaded a pre-built package.)
More detailed documentation is available from the project site, at
-["Building Spark with Maven"](http://spark.apache.org/docs/latest/building-with-maven.html).
+["Building Spark"](http://spark.apache.org/docs/latest/building-spark.html).
## Interactive Scala Shell
diff --git a/assembly/pom.xml b/assembly/pom.xml
index 31a01e4d8e1de..b0c9bca9b0e87 100644
--- a/assembly/pom.xml
+++ b/assembly/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.2.1-palantir2</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -66,22 +66,22 @@
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-repl_${scala.binary.version}</artifactId>
+      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-streaming_${scala.binary.version}</artifactId>
+      <artifactId>spark-graphx_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-graphx_${scala.binary.version}</artifactId>
+      <artifactId>spark-sql_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-sql_${scala.binary.version}</artifactId>
+      <artifactId>spark-repl_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
@@ -197,6 +197,11 @@
           <artifactId>spark-hive_${scala.binary.version}</artifactId>
           <version>${project.version}</version>
         </dependency>
+      </dependencies>
+    </profile>
+    <profile>
+      <id>hive-thriftserver</id>
+      <dependencies>
         <dependency>
           <groupId>org.apache.spark</groupId>
           <artifactId>spark-hive-thriftserver_${scala.binary.version}</artifactId>
diff --git a/bagel/pom.xml b/bagel/pom.xml
index 93db0d5efda5f..4022d693dbeec 100644
--- a/bagel/pom.xml
+++ b/bagel/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.2.1-palantir2</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
diff --git a/bin/beeline.cmd b/bin/beeline.cmd
new file mode 100644
index 0000000000000..8293f311029dd
--- /dev/null
+++ b/bin/beeline.cmd
@@ -0,0 +1,21 @@
+@echo off
+
+rem
+rem Licensed to the Apache Software Foundation (ASF) under one or more
+rem contributor license agreements. See the NOTICE file distributed with
+rem this work for additional information regarding copyright ownership.
+rem The ASF licenses this file to You under the Apache License, Version 2.0
+rem (the "License"); you may not use this file except in compliance with
+rem the License. You may obtain a copy of the License at
+rem
+rem http://www.apache.org/licenses/LICENSE-2.0
+rem
+rem Unless required by applicable law or agreed to in writing, software
+rem distributed under the License is distributed on an "AS IS" BASIS,
+rem WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+rem See the License for the specific language governing permissions and
+rem limitations under the License.
+rem
+
+set SPARK_HOME=%~dp0..
+cmd /V /E /C %SPARK_HOME%\bin\spark-class.cmd org.apache.hive.beeline.BeeLine %*
diff --git a/bin/compute-classpath.sh b/bin/compute-classpath.sh
index 905bbaf99b374..049b4a8515d3b 100755
--- a/bin/compute-classpath.sh
+++ b/bin/compute-classpath.sh
@@ -20,8 +20,6 @@
# This script computes Spark's classpath and prints it to stdout; it's used by both the "run"
# script and the ExecutorRunner in standalone cluster mode.
-SCALA_VERSION=2.10
-
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
@@ -36,7 +34,7 @@ else
CLASSPATH="$CLASSPATH:$FWDIR/conf"
fi
-ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SCALA_VERSION"
+ASSEMBLY_DIR="$FWDIR/assembly/target/scala-$SPARK_SCALA_VERSION"
if [ -n "$JAVA_HOME" ]; then
JAR_CMD="$JAVA_HOME/bin/jar"
@@ -48,19 +46,19 @@ fi
if [ -n "$SPARK_PREPEND_CLASSES" ]; then
echo "NOTE: SPARK_PREPEND_CLASSES is set, placing locally compiled Spark"\
"classes ahead of assembly." >&2
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/classes"
CLASSPATH="$CLASSPATH:$FWDIR/core/target/jars/*"
- CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SCALA_VERSION/classes"
- CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive-thriftserver/target/scala-$SPARK_SCALA_VERSION/classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/yarn/stable/target/scala-$SPARK_SCALA_VERSION/classes"
fi
# Use spark-assembly jar from either RELEASE or assembly directory
@@ -70,22 +68,25 @@ else
assembly_folder="$ASSEMBLY_DIR"
fi
-num_jars="$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*\.jar" | wc -l)"
-if [ "$num_jars" -eq "0" ]; then
- echo "Failed to find Spark assembly in $assembly_folder"
- echo "You need to build Spark before running this program."
- exit 1
-fi
+num_jars=0
+
+for f in ${assembly_folder}/spark-assembly*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark assembly in $assembly_folder" 1>&2
+ echo "You need to build Spark before running this program." 1>&2
+ exit 1
+ fi
+ ASSEMBLY_JAR="$f"
+ num_jars=$((num_jars+1))
+done
+
if [ "$num_jars" -gt "1" ]; then
- jars_list=$(ls "$assembly_folder" | grep "spark-assembly.*hadoop.*.jar")
- echo "Found multiple Spark assembly jars in $assembly_folder:"
- echo "$jars_list"
- echo "Please remove all but one jar."
+ echo "Found multiple Spark assembly jars in $assembly_folder:" 1>&2
+ ls ${assembly_folder}/spark-assembly*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
-ASSEMBLY_JAR="$(ls "$assembly_folder"/spark-assembly*hadoop*.jar 2>/dev/null)"
-
# Verify that versions of java used to build the jars and run Spark are compatible
jar_error_check=$("$JAR_CMD" -tf "$ASSEMBLY_JAR" nonexistent/class/path 2>&1)
if [[ "$jar_error_check" =~ "invalid CEN header" ]]; then
@@ -123,15 +124,15 @@ fi
# Add test classes if we're running from SBT or Maven with SPARK_TESTING set to 1
if [[ $SPARK_TESTING == 1 ]]; then
- CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SCALA_VERSION/test-classes"
- CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/repl/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/mllib/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/bagel/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/graphx/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/streaming/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/catalyst/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/core/target/scala-$SPARK_SCALA_VERSION/test-classes"
+ CLASSPATH="$CLASSPATH:$FWDIR/sql/hive/target/scala-$SPARK_SCALA_VERSION/test-classes"
fi
# Add hadoop conf dir if given -- otherwise FileSystem.*, etc fail !
diff --git a/bin/load-spark-env.sh b/bin/load-spark-env.sh
index 6d4231b204595..356b3d49b2ffe 100644
--- a/bin/load-spark-env.sh
+++ b/bin/load-spark-env.sh
@@ -36,3 +36,23 @@ if [ -z "$SPARK_ENV_LOADED" ]; then
set +a
fi
fi
+
+# Setting SPARK_SCALA_VERSION if not already set.
+
+if [ -z "$SPARK_SCALA_VERSION" ]; then
+
+ ASSEMBLY_DIR2="$FWDIR/assembly/target/scala-2.11"
+ ASSEMBLY_DIR1="$FWDIR/assembly/target/scala-2.10"
+
+ if [[ -d "$ASSEMBLY_DIR2" && -d "$ASSEMBLY_DIR1" ]]; then
+ echo -e "Presence of build for both scala versions(SCALA 2.10 and SCALA 2.11) detected." 1>&2
+ echo -e 'Either clean one of them or, export SPARK_SCALA_VERSION=2.11 in spark-env.sh.' 1>&2
+ exit 1
+ fi
+
+ if [ -d "$ASSEMBLY_DIR2" ]; then
+ export SPARK_SCALA_VERSION="2.11"
+ else
+ export SPARK_SCALA_VERSION="2.10"
+ fi
+fi
diff --git a/bin/pyspark b/bin/pyspark
index 96f30a260a09e..0b4f695dd06dd 100755
--- a/bin/pyspark
+++ b/bin/pyspark
@@ -25,7 +25,7 @@ export SPARK_HOME="$FWDIR"
source "$FWDIR/bin/utils.sh"
-SCALA_VERSION=2.10
+source "$FWDIR"/bin/load-spark-env.sh
function usage() {
echo "Usage: ./bin/pyspark [options]" 1>&2
@@ -40,7 +40,7 @@ fi
# Exit if the user hasn't compiled Spark
if [ ! -f "$FWDIR/RELEASE" ]; then
# Exit if the user hasn't compiled Spark
- ls "$FWDIR"/assembly/target/scala-$SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
+ ls "$FWDIR"/assembly/target/scala-$SPARK_SCALA_VERSION/spark-assembly*hadoop*.jar >& /dev/null
if [[ $? != 0 ]]; then
echo "Failed to find Spark assembly in $FWDIR/assembly/target" 1>&2
echo "You need to build Spark before running this program" 1>&2
@@ -48,8 +48,6 @@ if [ ! -f "$FWDIR/RELEASE" ]; then
fi
fi
-. "$FWDIR"/bin/load-spark-env.sh
-
# In Spark <= 1.1, setting IPYTHON=1 would cause the driver to be launched using the `ipython`
# executable, while the worker would still be launched using PYSPARK_PYTHON.
#
@@ -134,7 +132,5 @@ if [[ "$1" =~ \.py$ ]]; then
gatherSparkSubmitOpts "$@"
exec "$FWDIR"/bin/spark-submit "${SUBMISSION_OPTS[@]}" "$primary" "${APPLICATION_OPTS[@]}"
else
- # PySpark shell requires special handling downstream
- export PYSPARK_SHELL=1
exec "$PYSPARK_DRIVER_PYTHON" $PYSPARK_DRIVER_PYTHON_OPTS
fi
diff --git a/bin/pyspark2.cmd b/bin/pyspark2.cmd
index 59415e9bdec2c..a542ec80b49d6 100644
--- a/bin/pyspark2.cmd
+++ b/bin/pyspark2.cmd
@@ -59,7 +59,6 @@ for /f %%i in ('echo %1^| findstr /R "\.py"') do (
)
if [%PYTHON_FILE%] == [] (
- set PYSPARK_SHELL=1
if [%IPYTHON%] == [1] (
ipython %IPYTHON_OPTS%
) else (
diff --git a/bin/run-example b/bin/run-example
index 34dd71c71880e..c567acf9a6b5c 100755
--- a/bin/run-example
+++ b/bin/run-example
@@ -17,12 +17,12 @@
# limitations under the License.
#
-SCALA_VERSION=2.10
-
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
export SPARK_HOME="$FWDIR"
EXAMPLES_DIR="$FWDIR"/examples
+. "$FWDIR"/bin/load-spark-env.sh
+
if [ -n "$1" ]; then
EXAMPLE_CLASS="$1"
shift
@@ -35,17 +35,32 @@ else
fi
if [ -f "$FWDIR/RELEASE" ]; then
- export SPARK_EXAMPLES_JAR="`ls "$FWDIR"/lib/spark-examples-*hadoop*.jar`"
-elif [ -e "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar ]; then
- export SPARK_EXAMPLES_JAR="`ls "$EXAMPLES_DIR"/target/scala-$SCALA_VERSION/spark-examples-*hadoop*.jar`"
+ JAR_PATH="${FWDIR}/lib"
+else
+ JAR_PATH="${EXAMPLES_DIR}/target/scala-${SPARK_SCALA_VERSION}"
fi
-if [[ -z "$SPARK_EXAMPLES_JAR" ]]; then
- echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
- echo "You need to build Spark before running this program" 1>&2
+JAR_COUNT=0
+
+for f in ${JAR_PATH}/spark-examples-*hadoop*.jar; do
+ if [[ ! -e "$f" ]]; then
+ echo "Failed to find Spark examples assembly in $FWDIR/lib or $FWDIR/examples/target" 1>&2
+ echo "You need to build Spark before running this program" 1>&2
+ exit 1
+ fi
+ SPARK_EXAMPLES_JAR="$f"
+ JAR_COUNT=$((JAR_COUNT+1))
+done
+
+if [ "$JAR_COUNT" -gt "1" ]; then
+ echo "Found multiple Spark examples assembly jars in ${JAR_PATH}" 1>&2
+ ls ${JAR_PATH}/spark-examples-*hadoop*.jar 1>&2
+ echo "Please remove all but one jar." 1>&2
exit 1
fi
+export SPARK_EXAMPLES_JAR
+
EXAMPLE_MASTER=${MASTER:-"local[*]"}
if [[ ! $EXAMPLE_CLASS == org.apache.spark.examples* ]]; then
diff --git a/bin/spark-class b/bin/spark-class
index 925367b0dd187..3e6c367f17f40 100755
--- a/bin/spark-class
+++ b/bin/spark-class
@@ -24,13 +24,12 @@ case "`uname`" in
CYGWIN*) cygwin=true;;
esac
-SCALA_VERSION=2.10
-
# Figure out where Spark is installed
FWDIR="$(cd "`dirname "$0"`"/..; pwd)"
# Export this as SPARK_HOME
export SPARK_HOME="$FWDIR"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
. "$FWDIR"/bin/load-spark-env.sh
@@ -120,17 +119,17 @@ fi
JAVA_OPTS="$JAVA_OPTS -Xms$OUR_JAVA_MEM -Xmx$OUR_JAVA_MEM"
# Load extra JAVA_OPTS from conf/java-opts, if it exists
-if [ -e "$FWDIR/conf/java-opts" ] ; then
- JAVA_OPTS="$JAVA_OPTS `cat "$FWDIR"/conf/java-opts`"
+if [ -e "$SPARK_CONF_DIR/java-opts" ] ; then
+ JAVA_OPTS="$JAVA_OPTS `cat "$SPARK_CONF_DIR"/java-opts`"
fi
# Attention: when changing the way the JAVA_OPTS are assembled, the change must be reflected in CommandUtils.scala!
TOOLS_DIR="$FWDIR"/tools
SPARK_TOOLS_JAR=""
-if [ -e "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then
+if [ -e "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar ]; then
# Use the JAR from the SBT build
- export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SCALA_VERSION/spark-tools*[0-9Tg].jar`"
+ export SPARK_TOOLS_JAR="`ls "$TOOLS_DIR"/target/scala-$SPARK_SCALA_VERSION/spark-tools*[0-9Tg].jar`"
fi
if [ -e "$TOOLS_DIR"/target/spark-tools*[0-9Tg].jar ]; then
# Use the JAR from the Maven build
@@ -149,7 +148,7 @@ fi
if [[ "$1" =~ org.apache.spark.tools.* ]]; then
if test -z "$SPARK_TOOLS_JAR"; then
- echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SCALA_VERSION/" 1>&2
+ echo "Failed to find Spark Tools Jar in $FWDIR/tools/target/scala-$SPARK_SCALA_VERSION/" 1>&2
echo "You need to build Spark before running $1." 1>&2
exit 1
fi
diff --git a/bin/spark-shell b/bin/spark-shell
index 4a0670fc6c8aa..cca5aa0676123 100755
--- a/bin/spark-shell
+++ b/bin/spark-shell
@@ -45,6 +45,13 @@ source "$FWDIR"/bin/utils.sh
SUBMIT_USAGE_FUNCTION=usage
gatherSparkSubmitOpts "$@"
+# SPARK-4161: scala does not assume use of the java classpath,
+# so we need to add the "-Dscala.usejavacp=true" flag manually. We
+# do this specifically for the Spark shell because the scala REPL
+# has its own class loader, and any additional classpath specified
+# through spark.driver.extraClassPath is not automatically propagated.
+SPARK_SUBMIT_OPTS="$SPARK_SUBMIT_OPTS -Dscala.usejavacp=true"
+
function main() {
if $cygwin; then
# Workaround for issue involving JLine and Cygwin
diff --git a/bin/spark-submit b/bin/spark-submit
index c557311b4b20e..216b92e411bbd 100755
--- a/bin/spark-submit
+++ b/bin/spark-submit
@@ -20,8 +20,13 @@
# NOTE: Any changes in this file must be reflected in SparkSubmitDriverBootstrapper.scala!
export SPARK_HOME="$(cd "`dirname "$0"`"/..; pwd)"
+export SPARK_CONF_DIR="${SPARK_CONF_DIR:-"$SPARK_HOME/conf"}"
+
ORIG_ARGS=("$@")
+# Set COLUMNS for progress bar
+export COLUMNS=`tput cols`
+
while (($#)); do
if [ "$1" = "--deploy-mode" ]; then
SPARK_SUBMIT_DEPLOY_MODE=$2
@@ -35,11 +40,16 @@ while (($#)); do
export SPARK_SUBMIT_CLASSPATH=$2
elif [ "$1" = "--driver-java-options" ]; then
export SPARK_SUBMIT_OPTS=$2
+ elif [ "$1" = "--master" ]; then
+ export MASTER=$2
fi
shift
done
-DEFAULT_PROPERTIES_FILE="$SPARK_HOME/conf/spark-defaults.conf"
+DEFAULT_PROPERTIES_FILE="$SPARK_CONF_DIR/spark-defaults.conf"
+if [ "$MASTER" == "yarn-cluster" ]; then
+ SPARK_SUBMIT_DEPLOY_MODE=cluster
+fi
export SPARK_SUBMIT_DEPLOY_MODE=${SPARK_SUBMIT_DEPLOY_MODE:-"client"}
export SPARK_SUBMIT_PROPERTIES_FILE=${SPARK_SUBMIT_PROPERTIES_FILE:-"$DEFAULT_PROPERTIES_FILE"}
diff --git a/bin/spark-submit2.cmd b/bin/spark-submit2.cmd
index cf6046d1547ad..4581264b586d0 100644
--- a/bin/spark-submit2.cmd
+++ b/bin/spark-submit2.cmd
@@ -24,13 +24,18 @@ set ORIG_ARGS=%*
rem Reset the values of all variables used
set SPARK_SUBMIT_DEPLOY_MODE=client
-set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
set SPARK_SUBMIT_DRIVER_MEMORY=
set SPARK_SUBMIT_LIBRARY_PATH=
set SPARK_SUBMIT_CLASSPATH=
set SPARK_SUBMIT_OPTS=
set SPARK_SUBMIT_BOOTSTRAP_DRIVER=
+if not "x%SPARK_CONF_DIR%"=="x" (
+ set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_CONF_DIR%\spark-defaults.conf
+) else (
+ set SPARK_SUBMIT_PROPERTIES_FILE=%SPARK_HOME%\conf\spark-defaults.conf
+)
+
:loop
if [%1] == [] goto continue
if [%1] == [--deploy-mode] (
@@ -45,11 +50,17 @@ if [%1] == [] goto continue
set SPARK_SUBMIT_CLASSPATH=%2
) else if [%1] == [--driver-java-options] (
set SPARK_SUBMIT_OPTS=%2
+ ) else if [%1] == [--master] (
+ set MASTER=%2
)
shift
goto loop
:continue
+if [%MASTER%] == [yarn-cluster] (
+ set SPARK_SUBMIT_DEPLOY_MODE=cluster
+)
+
rem For client mode, the driver will be launched in the same JVM that launches
rem SparkSubmit, so we may need to read the properties file for any extra class
rem paths, library paths, java options and memory early on. Otherwise, it will
diff --git a/conf/metrics.properties.template b/conf/metrics.properties.template
index 30bcab0c93302..96b6844f0aabb 100644
--- a/conf/metrics.properties.template
+++ b/conf/metrics.properties.template
@@ -77,8 +77,8 @@
# sample false Whether to show entire set of samples for histograms ('false' or 'true')
#
# * Default path is /metrics/json for all instances except the master. The master has two paths:
-# /metrics/aplications/json # App information
-# /metrics/master/json # Master information
+# /metrics/applications/json # App information
+# /metrics/master/json # Master information
# org.apache.spark.metrics.sink.GraphiteSink
# Name: Default: Description:
diff --git a/conf/spark-env.sh.template b/conf/spark-env.sh.template
index f8ffbf64278fb..0886b0276fb90 100755
--- a/conf/spark-env.sh.template
+++ b/conf/spark-env.sh.template
@@ -28,7 +28,7 @@
# - SPARK_YARN_DIST_FILES, Comma separated list of files to be distributed with the job.
# - SPARK_YARN_DIST_ARCHIVES, Comma separated list of archives to be distributed with the job.
-# Options for the daemons used in the standalone deploy mode:
+# Options for the daemons used in the standalone deploy mode
# - SPARK_MASTER_IP, to bind the master to a different IP address or hostname
# - SPARK_MASTER_PORT / SPARK_MASTER_WEBUI_PORT, to use non-default ports for the master
# - SPARK_MASTER_OPTS, to set config properties only for the master (e.g. "-Dx=y")
@@ -41,3 +41,10 @@
# - SPARK_HISTORY_OPTS, to set config properties only for the history server (e.g. "-Dx=y")
# - SPARK_DAEMON_JAVA_OPTS, to set config properties for all daemons (e.g. "-Dx=y")
# - SPARK_PUBLIC_DNS, to set the public dns name of the master or workers
+
+# Generic options for the daemons used in the standalone deploy mode
+# - SPARK_CONF_DIR Alternate conf dir. (Default: ${SPARK_HOME}/conf)
+# - SPARK_LOG_DIR Where log files are stored. (Default: ${SPARK_HOME}/logs)
+# - SPARK_PID_DIR Where the pid file is stored. (Default: /tmp)
+# - SPARK_IDENT_STRING A string representing this instance of spark. (Default: $USER)
+# - SPARK_NICENESS The scheduling priority for daemons. (Default: 0)
diff --git a/core/pom.xml b/core/pom.xml
index 41296e0eca330..43936efbbf3e0 100644
--- a/core/pom.xml
+++ b/core/pom.xml
@@ -21,7 +21,7 @@
   <parent>
     <groupId>org.apache.spark</groupId>
     <artifactId>spark-parent</artifactId>
-    <version>1.2.0-SNAPSHOT</version>
+    <version>1.2.1-palantir2</version>
     <relativePath>../pom.xml</relativePath>
   </parent>
@@ -34,6 +34,34 @@
   <name>Spark Project Core</name>
   <url>http://spark.apache.org/</url>
   <dependencies>
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill_${scala.binary.version}</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm-commons</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
+    <dependency>
+      <groupId>com.twitter</groupId>
+      <artifactId>chill-java</artifactId>
+      <exclusions>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm</artifactId>
+        </exclusion>
+        <exclusion>
+          <groupId>org.ow2.asm</groupId>
+          <artifactId>asm-commons</artifactId>
+        </exclusion>
+      </exclusions>
+    </dependency>
     <dependency>
       <groupId>org.apache.hadoop</groupId>
       <artifactId>hadoop-client</artifactId>
@@ -46,12 +74,12 @@
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-network-common_2.10</artifactId>
+      <artifactId>spark-network-common_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
     <dependency>
       <groupId>org.apache.spark</groupId>
-      <artifactId>spark-network-shuffle_2.10</artifactId>
+      <artifactId>spark-network-shuffle_${scala.binary.version}</artifactId>
       <version>${project.version}</version>
     </dependency>
@@ -132,14 +160,6 @@
     <dependency>
       <groupId>net.jpountz.lz4</groupId>
       <artifactId>lz4</artifactId>
     </dependency>
-    <dependency>
-      <groupId>com.twitter</groupId>
-      <artifactId>chill_${scala.binary.version}</artifactId>
-    </dependency>
-    <dependency>
-      <groupId>com.twitter</groupId>
-      <artifactId>chill-java</artifactId>
-    </dependency>
     <dependency>
       <groupId>org.roaringbitmap</groupId>
       <artifactId>RoaringBitmap</artifactId>
@@ -309,14 +329,16 @@
       <plugin>
         <groupId>org.scalatest</groupId>
         <artifactId>scalatest-maven-plugin</artifactId>
-        <configuration>
-          <environmentVariables>
-            <SPARK_HOME>${basedir}/..</SPARK_HOME>
-            <SPARK_TESTING>1</SPARK_TESTING>
-            <SPARK_CLASSPATH>${spark.classpath}</SPARK_CLASSPATH>
-          </environmentVariables>
-        </configuration>
+        <executions>
+          <execution>
+            <id>test</id>
+            <goals>
+              <goal>test</goal>
+            </goals>
+          </execution>
+        </executions>
       </plugin>
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
@@ -424,4 +446,5 @@
+
diff --git a/core/src/main/java/org/apache/spark/SparkJobInfo.java b/core/src/main/java/org/apache/spark/SparkJobInfo.java
index 4e3c983b1170a..e31c4401632a6 100644
--- a/core/src/main/java/org/apache/spark/SparkJobInfo.java
+++ b/core/src/main/java/org/apache/spark/SparkJobInfo.java
@@ -17,13 +17,15 @@
package org.apache.spark;
+import java.io.Serializable;
+
/**
* Exposes information about Spark Jobs.
*
* This interface is not designed to be implemented outside of Spark. We may add additional methods
* which may break binary compatibility with outside implementations.
*/
-public interface SparkJobInfo {
+public interface SparkJobInfo extends Serializable {
int jobId();
int[] stageIds();
JobExecutionStatus status();
diff --git a/core/src/main/java/org/apache/spark/SparkStageInfo.java b/core/src/main/java/org/apache/spark/SparkStageInfo.java
index 04e2247210ecc..b7d462abd72d6 100644
--- a/core/src/main/java/org/apache/spark/SparkStageInfo.java
+++ b/core/src/main/java/org/apache/spark/SparkStageInfo.java
@@ -17,15 +17,18 @@
package org.apache.spark;
+import java.io.Serializable;
+
/**
* Exposes information about Spark Stages.
*
* This interface is not designed to be implemented outside of Spark. We may add additional methods
* which may break binary compatibility with outside implementations.
*/
-public interface SparkStageInfo {
+public interface SparkStageInfo extends Serializable {
int stageId();
int currentAttemptId();
+ long submissionTime();
String name();
int numTasks();
int numActiveTasks();
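
The two status interfaces above now extend java.io.Serializable (and SparkStageInfo additionally exposes submissionTime()), so status snapshots can be serialized and shipped out of the driver process intact. A minimal sketch of why that matters; MyJobInfo is a hypothetical class for illustration only, since these interfaces are not meant to be implemented outside Spark:

```scala
import java.io._
import org.apache.spark.{JobExecutionStatus, SparkJobInfo}

// Hypothetical implementation, used only to demonstrate the Serializable round trip.
class MyJobInfo(id: Int, stages: Array[Int]) extends SparkJobInfo {
  override def jobId(): Int = id
  override def stageIds(): Array[Int] = stages
  override def status(): JobExecutionStatus = JobExecutionStatus.RUNNING
}

object SerializableStatusDemo {
  def main(args: Array[String]): Unit = {
    val info: SparkJobInfo = new MyJobInfo(1, Array(0, 1))
    // Because SparkJobInfo now extends Serializable, this Java-serialization round trip succeeds.
    val bytes = new ByteArrayOutputStream()
    val out = new ObjectOutputStream(bytes)
    out.writeObject(info)
    out.close()
    val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
    val restored = in.readObject().asInstanceOf[SparkJobInfo]
    println(s"job ${restored.jobId()} status ${restored.status()}")
  }
}
```
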
diff --git a/core/src/main/java/org/apache/spark/api/java/function/package.scala b/core/src/main/java/org/apache/spark/api/java/function/package.scala
index 7f91de653a64a..0f9bac7164162 100644
--- a/core/src/main/java/org/apache/spark/api/java/function/package.scala
+++ b/core/src/main/java/org/apache/spark/api/java/function/package.scala
@@ -22,4 +22,4 @@ package org.apache.spark.api.java
* these interfaces to pass functions to various Java API methods for Spark. Please visit Spark's
* Java programming guide for more details.
*/
-package object function
\ No newline at end of file
+package object function
diff --git a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
index c5936b5038ac9..14ba37d7c9bd9 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/additional-metrics.js
@@ -26,24 +26,24 @@ $(function() {
// Switch the class of the arrow from open to closed.
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-open');
$(this).find('.expand-additional-metrics-arrow').toggleClass('arrow-closed');
-
- // If clicking caused the metrics to expand, automatically check all options for additional
- // metrics (don't trigger a click when collapsing metrics, because it leads to weird
- // toggling behavior).
- if (!$(additionalMetricsDiv).hasClass('collapsed')) {
- $(this).parent().find('input:checkbox:not(:checked)').trigger('click');
- }
});
- $("input:checkbox:not(:checked)").each(function() {
- var column = "table ." + $(this).attr("name");
- $(column).hide();
- });
+ stripeSummaryTable();
$("input:checkbox").click(function() {
var column = "table ." + $(this).attr("name");
$(column).toggle();
- stripeTables();
+ stripeSummaryTable();
+ });
+
+ $("#select-all-metrics").click(function() {
+ if (this.checked) {
+ // Toggle all un-checked options.
+ $('input:checkbox:not(:checked)').trigger('click');
+ } else {
+ // Toggle all checked options.
+ $('input:checkbox:checked').trigger('click');
+ }
});
// Trigger a click on the checkbox if a user clicks the label next to it.
diff --git a/core/src/main/resources/org/apache/spark/ui/static/table.js b/core/src/main/resources/org/apache/spark/ui/static/table.js
index 32187ba6e8df0..656147e40d13e 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/table.js
+++ b/core/src/main/resources/org/apache/spark/ui/static/table.js
@@ -15,21 +15,18 @@
* limitations under the License.
*/
-/* Adds background colors to stripe table rows. This is necessary (instead of using css or the
- * table striping provided by bootstrap) to appropriately stripe tables with hidden rows. */
-function stripeTables() {
- $("table.table-striped-custom").each(function() {
- $(this).find("tr:not(:hidden)").each(function (index) {
- if (index % 2 == 1) {
- $(this).css("background-color", "#f9f9f9");
- } else {
- $(this).css("background-color", "#ffffff");
- }
- });
+/* Adds background colors to stripe table rows in the summary table (on the stage page). This is
+ * necessary (instead of using css or the table striping provided by bootstrap) because the summary
+ * table has hidden rows.
+ *
+ * An ID selector (rather than a class selector) is used to ensure this runs quickly even on pages
+ * with thousands of task rows (ID selectors are much faster than class selectors). */
+function stripeSummaryTable() {
+ $("#task-summary-table").find("tr:not(:hidden)").each(function (index) {
+ if (index % 2 == 1) {
+ $(this).css("background-color", "#f9f9f9");
+ } else {
+ $(this).css("background-color", "#ffffff");
+ }
});
}
-
-/* Stripe all tables after pages finish loading. */
-$(function() {
- stripeTables();
-});
diff --git a/core/src/main/resources/org/apache/spark/ui/static/webui.css b/core/src/main/resources/org/apache/spark/ui/static/webui.css
index a2220e761ac98..5751964b792ce 100644
--- a/core/src/main/resources/org/apache/spark/ui/static/webui.css
+++ b/core/src/main/resources/org/apache/spark/ui/static/webui.css
@@ -120,6 +120,20 @@ pre {
border: none;
}
+.stacktrace-details {
+ max-height: 300px;
+ overflow-y: auto;
+ margin: 0;
+ transition: max-height 0.5s ease-out, padding 0.5s ease-out;
+}
+
+.stacktrace-details.collapsed {
+ max-height: 0;
+ padding-top: 0;
+ padding-bottom: 0;
+ border: none;
+}
+
span.expand-additional-metrics {
cursor: pointer;
}
@@ -154,3 +168,19 @@ span.additional-metric-title {
border-left: 5px solid black;
display: inline-block;
}
+
+.version {
+ line-height: 30px;
+ vertical-align: bottom;
+ font-size: 12px;
+ padding: 0;
+ margin: 0;
+ font-weight: bold;
+ color: #777;
+}
+
+/* Hide all additional metrics by default. This is done here rather than using JavaScript to
+ * avoid slow page loads for stage pages with large numbers (e.g., thousands) of tasks. */
+.scheduler_delay, .deserialization_time, .serialization_time, .getting_result_time {
+ display: none;
+}
diff --git a/core/src/main/scala/org/apache/spark/Accumulators.scala b/core/src/main/scala/org/apache/spark/Accumulators.scala
index 2301caafb07ff..6ef4ff5543b0a 100644
--- a/core/src/main/scala/org/apache/spark/Accumulators.scala
+++ b/core/src/main/scala/org/apache/spark/Accumulators.scala
@@ -18,6 +18,8 @@
package org.apache.spark
import java.io.{ObjectInputStream, Serializable}
+import java.util.concurrent.atomic.AtomicLong
+import java.lang.ThreadLocal
import scala.collection.generic.Growable
import scala.collection.mutable.Map
@@ -228,6 +230,7 @@ GrowableAccumulableParam[R <% Growable[T] with TraversableOnce[T] with Serializa
*/
class Accumulator[T](@transient initialValue: T, param: AccumulatorParam[T], name: Option[String])
extends Accumulable[T,T](initialValue, param, name) {
+
def this(initialValue: T, param: AccumulatorParam[T]) = this(initialValue, param, None)
}
@@ -246,13 +249,15 @@ trait AccumulatorParam[T] extends AccumulableParam[T, T] {
// TODO: The multi-thread support in accumulators is kind of lame; check
// if there's a more intuitive way of doing it right
-private object Accumulators {
+private[spark] object Accumulators {
// TODO: Use soft references? => need to make readObject work properly then
val originals = Map[Long, Accumulable[_, _]]()
- val localAccums = Map[Thread, Map[Long, Accumulable[_, _]]]()
+ val localAccums = new ThreadLocal[Map[Long, Accumulable[_, _]]]() {
+ override protected def initialValue() = Map[Long, Accumulable[_, _]]()
+ }
var lastId: Long = 0
- def newId: Long = synchronized {
+ def newId(): Long = synchronized {
lastId += 1
lastId
}
@@ -261,22 +266,21 @@ private object Accumulators {
if (original) {
originals(a.id) = a
} else {
- val accums = localAccums.getOrElseUpdate(Thread.currentThread, Map())
- accums(a.id) = a
+ localAccums.get()(a.id) = a
}
}
// Clear the local (non-original) accumulators for the current thread
def clear() {
synchronized {
- localAccums.remove(Thread.currentThread)
+ localAccums.get.clear
}
}
// Get the values of the local accumulators for the current thread (by ID)
def values: Map[Long, Any] = synchronized {
val ret = Map[Long, Any]()
- for ((id, accum) <- localAccums.getOrElse(Thread.currentThread, Map())) {
+ for ((id, accum) <- localAccums.get) {
ret(id) = accum.localValue
}
return ret
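
The change above swaps a Thread-keyed map for a ThreadLocal whose initialValue supplies an empty per-thread map, so each task thread registers and reads its own accumulators without touching other threads' entries. A standalone sketch of that pattern, with a hypothetical PerThreadRegistry standing in for Spark's Accumulators object:

```scala
import scala.collection.mutable

// Hypothetical per-thread registry illustrating the ThreadLocal-with-initialValue pattern.
object PerThreadRegistry {
  private val localValues = new ThreadLocal[mutable.Map[Long, Any]]() {
    override protected def initialValue(): mutable.Map[Long, Any] = mutable.Map.empty[Long, Any]
  }

  def register(id: Long, value: Any): Unit = localValues.get()(id) = value
  def snapshot(): Map[Long, Any] = localValues.get().toMap
  def clear(): Unit = localValues.get().clear()
}

object PerThreadRegistryDemo {
  private def worker(i: Int): Thread = new Thread(new Runnable {
    override def run(): Unit = {
      PerThreadRegistry.register(i.toLong, s"value-$i")
      // Each thread only ever sees the entries it registered itself.
      println(s"thread $i sees ${PerThreadRegistry.snapshot()}")
    }
  })

  def main(args: Array[String]): Unit = {
    val threads = Seq(worker(1), worker(2))
    threads.foreach(_.start())
    threads.foreach(_.join())
  }
}
```
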
diff --git a/core/src/main/scala/org/apache/spark/Aggregator.scala b/core/src/main/scala/org/apache/spark/Aggregator.scala
index 79c9c451d273d..09eb9605fb799 100644
--- a/core/src/main/scala/org/apache/spark/Aggregator.scala
+++ b/core/src/main/scala/org/apache/spark/Aggregator.scala
@@ -34,7 +34,9 @@ case class Aggregator[K, V, C] (
mergeValue: (C, V) => C,
mergeCombiners: (C, C) => C) {
- private val externalSorting = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
+ // When spilling is enabled sorting will happen externally, but not necessarily with an
+ // ExternalSorter.
+ private val isSpillEnabled = SparkEnv.get.conf.getBoolean("spark.shuffle.spill", true)
@deprecated("use combineValuesByKey with TaskContext argument", "0.9.0")
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]]): Iterator[(K, C)] =
@@ -42,7 +44,7 @@ case class Aggregator[K, V, C] (
def combineValuesByKey(iter: Iterator[_ <: Product2[K, V]],
context: TaskContext): Iterator[(K, C)] = {
- if (!externalSorting) {
+ if (!isSpillEnabled) {
val combiners = new AppendOnlyMap[K,C]
var kv: Product2[K, V] = null
val update = (hadValue: Boolean, oldValue: C) => {
@@ -71,9 +73,9 @@ case class Aggregator[K, V, C] (
combineCombinersByKey(iter, null)
def combineCombinersByKey(iter: Iterator[_ <: Product2[K, C]], context: TaskContext)
- : Iterator[(K, C)] =
+ : Iterator[(K, C)] =
{
- if (!externalSorting) {
+ if (!isSpillEnabled) {
val combiners = new AppendOnlyMap[K,C]
var kc: Product2[K, C] = null
val update = (hadValue: Boolean, oldValue: C) => {
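
When spark.shuffle.spill is disabled, the Aggregator above combines values for each key in a purely in-memory map. A rough sketch of that combine-by-key path, using an ordinary mutable.HashMap as a stand-in for Spark's AppendOnlyMap; CombineByKeySketch is a hypothetical helper, not Spark code:

```scala
import scala.collection.mutable

object CombineByKeySketch {
  // createCombiner/mergeValue mirror two of Aggregator's parameters.
  def combineValuesByKey[K, V, C](
      iter: Iterator[(K, V)],
      createCombiner: V => C,
      mergeValue: (C, V) => C): Iterator[(K, C)] = {
    val combiners = mutable.HashMap.empty[K, C]
    iter.foreach { case (k, v) =>
      combiners(k) = combiners.get(k) match {
        case Some(c) => mergeValue(c, v)   // key seen before: fold the new value in
        case None    => createCombiner(v)  // first value for this key
      }
    }
    combiners.iterator
  }

  def main(args: Array[String]): Unit = {
    val input = Iterator(("a", 1), ("b", 2), ("a", 3))
    // Sum values per key: the combiner is just the running Int total.
    val combined = combineValuesByKey[String, Int, Int](input, v => v, _ + _)
    combined.foreach(println)  // ("a", 4) and ("b", 2), in some order
  }
}
```
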
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
new file mode 100644
index 0000000000000..a46a81eabd965
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationClient.scala
@@ -0,0 +1,42 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * A client that communicates with the cluster manager to request or kill executors.
+ */
+private[spark] trait ExecutorAllocationClient {
+
+ /**
+ * Request an additional number of executors from the cluster manager.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def requestExecutors(numAdditionalExecutors: Int): Boolean
+
+ /**
+ * Request that the cluster manager kill the specified executors.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def killExecutors(executorIds: Seq[String]): Boolean
+
+ /**
+ * Request that the cluster manager kill the specified executor.
+ * Return whether the request is acknowledged by the cluster manager.
+ */
+ def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
+}
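
The new trait isolates the allocation policy from SparkContext: anything that can request and kill executors can drive it. Below is a hypothetical stub, shown only to illustrate the contract (in Spark the cluster-manager-backed scheduler backend plays this role):

```scala
package org.apache.spark

// Hypothetical stub; because the trait is private[spark], this sketch has to live
// in the org.apache.spark package.
class LoggingAllocationClient extends ExecutorAllocationClient {
  override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
    println(s"would request $numAdditionalExecutors additional executor(s)")
    true  // pretend the cluster manager acknowledged the request
  }

  override def killExecutors(executorIds: Seq[String]): Boolean = {
    println(s"would kill executors: ${executorIds.mkString(", ")}")
    true
  }
  // killExecutor(executorId) comes for free from the trait's default method.
}
```
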
diff --git a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
index c11f1db0064fd..a0ee2a7cbb2a2 100644
--- a/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
+++ b/core/src/main/scala/org/apache/spark/ExecutorAllocationManager.scala
@@ -28,7 +28,9 @@ import org.apache.spark.scheduler._
* the scheduler queue is not drained in N seconds, then new executors are added. If the queue
* persists for another M seconds, then more executors are added and so on. The number added
* in each round increases exponentially from the previous round until an upper bound on the
- * number of executors has been reached.
+ * number of executors has been reached. The upper bound is based both on a configured property
+ * and on the number of tasks pending: the policy will never increase the number of executor
+ * requests past the number needed to handle all pending tasks.
*
* The rationale for the exponential increase is twofold: (1) Executors should be added slowly
* in the beginning in case the number of extra executors needed turns out to be small. Otherwise,
@@ -58,15 +60,19 @@ import org.apache.spark.scheduler._
* spark.dynamicAllocation.executorIdleTimeout (K) -
* If an executor has been idle for this duration, remove it
*/
-private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging {
- import ExecutorAllocationManager._
+private[spark] class ExecutorAllocationManager(
+ client: ExecutorAllocationClient,
+ listenerBus: LiveListenerBus,
+ conf: SparkConf)
+ extends Logging {
+
+ allocationManager =>
- private val conf = sc.conf
+ import ExecutorAllocationManager._
// Lower and upper bounds on the number of executors. These are required.
private val minNumExecutors = conf.getInt("spark.dynamicAllocation.minExecutors", -1)
private val maxNumExecutors = conf.getInt("spark.dynamicAllocation.maxExecutors", -1)
- verifyBounds()
// How long there must be backlogged tasks for before an addition is triggered
private val schedulerBacklogTimeout = conf.getLong(
@@ -77,9 +83,20 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
"spark.dynamicAllocation.sustainedSchedulerBacklogTimeout", schedulerBacklogTimeout)
// How long an executor must be idle for before it is removed
- private val removeThresholdSeconds = conf.getLong(
+ private val executorIdleTimeout = conf.getLong(
"spark.dynamicAllocation.executorIdleTimeout", 600)
+ // During testing, the methods to actually kill and add executors are mocked out
+ private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
+
+ // TODO: The default value of 1 for spark.executor.cores works right now because dynamic
+ // allocation is only supported for YARN and the default number of cores per executor in YARN is
+ // 1, but it might need to be attained differently for different cluster managers
+ private val tasksPerExecutor =
+ conf.getInt("spark.executor.cores", 1) / conf.getInt("spark.task.cpus", 1)
+
+ validateSettings()
+
// Number of executors to add in the next round
private var numExecutorsToAdd = 1
@@ -103,17 +120,17 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Polling loop interval (ms)
private val intervalMillis: Long = 100
- // Whether we are testing this class. This should only be used internally.
- private val testing = conf.getBoolean("spark.dynamicAllocation.testing", false)
-
// Clock used to schedule when executors should be added and removed
private var clock: Clock = new RealClock
+ // Listener for Spark events that impact the allocation policy
+ private val listener = new ExecutorAllocationListener
+
/**
- * Verify that the lower and upper bounds on the number of executors are valid.
+ * Verify that the settings specified through the config are valid.
* If not, throw an appropriate exception.
*/
- private def verifyBounds(): Unit = {
+ private def validateSettings(): Unit = {
if (minNumExecutors < 0 || maxNumExecutors < 0) {
throw new SparkException("spark.dynamicAllocation.{min/max}Executors must be set!")
}
@@ -124,6 +141,25 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
throw new SparkException(s"spark.dynamicAllocation.minExecutors ($minNumExecutors) must " +
s"be less than or equal to spark.dynamicAllocation.maxExecutors ($maxNumExecutors)!")
}
+ if (schedulerBacklogTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.schedulerBacklogTimeout must be > 0!")
+ }
+ if (sustainedSchedulerBacklogTimeout <= 0) {
+ throw new SparkException(
+ "spark.dynamicAllocation.sustainedSchedulerBacklogTimeout must be > 0!")
+ }
+ if (executorIdleTimeout <= 0) {
+ throw new SparkException("spark.dynamicAllocation.executorIdleTimeout must be > 0!")
+ }
+ // Require external shuffle service for dynamic allocation
+ // Otherwise, we may lose shuffle files when killing executors
+ if (!conf.getBoolean("spark.shuffle.service.enabled", false) && !testing) {
+ throw new SparkException("Dynamic allocation of executors requires the external " +
+ "shuffle service. You may enable this through spark.shuffle.service.enabled.")
+ }
+ if (tasksPerExecutor == 0) {
+ throw new SparkException("spark.executor.cores must not be less than spark.task.cpus.cores")
+ }
}
/**
@@ -137,8 +173,7 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
* Register for scheduler callbacks to decide when to add and remove executors.
*/
def start(): Unit = {
- val listener = new ExecutorAllocationListener(this)
- sc.addSparkListener(listener)
+ listenerBus.addListener(listener)
startPolling()
}
@@ -177,11 +212,12 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
addTime += sustainedSchedulerBacklogTimeout * 1000
}
- removeTimes.foreach { case (executorId, expireTime) =>
- if (now >= expireTime) {
+ removeTimes.retain { case (executorId, expireTime) =>
+ val expired = now >= expireTime
+ if (expired) {
removeExecutor(executorId)
- removeTimes.remove(executorId)
}
+ !expired
}
}
@@ -201,15 +237,29 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
return 0
}
- // Request executors with respect to the upper bound
- val actualNumExecutorsToAdd =
- if (numExistingExecutors + numExecutorsToAdd <= maxNumExecutors) {
- numExecutorsToAdd
- } else {
- maxNumExecutors - numExistingExecutors
- }
+ // The number of executors needed to satisfy all pending tasks is the number of tasks pending
+ // divided by the number of tasks each executor can fit, rounded up.
+ val maxNumExecutorsPending =
+ (listener.totalPendingTasks() + tasksPerExecutor - 1) / tasksPerExecutor
+ if (numExecutorsPending >= maxNumExecutorsPending) {
+ logDebug(s"Not adding executors because there are already $numExecutorsPending " +
+ s"pending and pending tasks could only fill $maxNumExecutorsPending")
+ numExecutorsToAdd = 1
+ return 0
+ }
+
+ // It's never useful to request more executors than could satisfy all the pending tasks, so
+ // cap request at that amount.
+ // Also cap request with respect to the configured upper bound.
+ val maxNumExecutorsToAdd = math.min(
+ maxNumExecutorsPending - numExecutorsPending,
+ maxNumExecutors - numExistingExecutors)
+ assert(maxNumExecutorsToAdd > 0)
+
+ val actualNumExecutorsToAdd = math.min(numExecutorsToAdd, maxNumExecutorsToAdd)
+
val newTotalExecutors = numExistingExecutors + actualNumExecutorsToAdd
- val addRequestAcknowledged = testing || sc.requestExecutors(actualNumExecutorsToAdd)
+ val addRequestAcknowledged = testing || client.requestExecutors(actualNumExecutorsToAdd)
if (addRequestAcknowledged) {
logInfo(s"Requesting $actualNumExecutorsToAdd new executor(s) because " +
s"tasks are backlogged (new desired total will be $newTotalExecutors)")
@@ -245,16 +295,16 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
// Do not kill the executor if we have already reached the lower bound
val numExistingExecutors = executorIds.size - executorsPendingToRemove.size
if (numExistingExecutors - 1 < minNumExecutors) {
- logInfo(s"Not removing idle executor $executorId because there are only " +
+ logDebug(s"Not removing idle executor $executorId because there are only " +
s"$numExistingExecutors executor(s) left (limit $minNumExecutors)")
return false
}
// Send a request to the backend to kill this executor
- val removeRequestAcknowledged = testing || sc.killExecutor(executorId)
+ val removeRequestAcknowledged = testing || client.killExecutor(executorId)
if (removeRequestAcknowledged) {
logInfo(s"Removing executor $executorId because it has been idle for " +
- s"$removeThresholdSeconds seconds (new desired total will be ${numExistingExecutors - 1})")
+ s"$executorIdleTimeout seconds (new desired total will be ${numExistingExecutors - 1})")
executorsPendingToRemove.add(executorId)
true
} else {
@@ -269,7 +319,11 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
private def onExecutorAdded(executorId: String): Unit = synchronized {
if (!executorIds.contains(executorId)) {
executorIds.add(executorId)
- executorIds.foreach(onExecutorIdle)
+ // If an executor (call this executor X) is not removed because the lower bound
+ // has been reached, it will no longer be marked as idle. When new executors join,
+ // however, we are no longer at the lower bound, and so we must mark executor X
+ // as idle again so as not to forget that it is a candidate for removal. (see SPARK-4951)
+ executorIds.filter(listener.isExecutorIdle).foreach(onExecutorIdle)
logInfo(s"New executor $executorId has registered (new total is ${executorIds.size})")
if (numExecutorsPending > 0) {
numExecutorsPending -= 1
@@ -327,10 +381,14 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
* the executor is not already marked as idle.
*/
private def onExecutorIdle(executorId: String): Unit = synchronized {
- if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
- logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
- s"scheduled to run on the executor (to expire in $removeThresholdSeconds seconds)")
- removeTimes(executorId) = clock.getTimeMillis + removeThresholdSeconds * 1000
+ if (executorIds.contains(executorId)) {
+ if (!removeTimes.contains(executorId) && !executorsPendingToRemove.contains(executorId)) {
+ logDebug(s"Starting idle timer for $executorId because there are no more tasks " +
+ s"scheduled to run on the executor (to expire in $executorIdleTimeout seconds)")
+ removeTimes(executorId) = clock.getTimeMillis + executorIdleTimeout * 1000
+ }
+ } else {
+ logWarning(s"Attempted to mark unknown executor $executorId idle")
}
}
@@ -350,25 +408,24 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
* and consistency of events returned by the listener. For simplicity, it does not account
* for speculated tasks.
*/
- private class ExecutorAllocationListener(allocationManager: ExecutorAllocationManager)
- extends SparkListener {
+ private class ExecutorAllocationListener extends SparkListener {
private val stageIdToNumTasks = new mutable.HashMap[Int, Int]
private val stageIdToTaskIndices = new mutable.HashMap[Int, mutable.HashSet[Int]]
private val executorIdToTaskIds = new mutable.HashMap[String, mutable.HashSet[Long]]
override def onStageSubmitted(stageSubmitted: SparkListenerStageSubmitted): Unit = {
- synchronized {
- val stageId = stageSubmitted.stageInfo.stageId
- val numTasks = stageSubmitted.stageInfo.numTasks
+ val stageId = stageSubmitted.stageInfo.stageId
+ val numTasks = stageSubmitted.stageInfo.numTasks
+ allocationManager.synchronized {
stageIdToNumTasks(stageId) = numTasks
allocationManager.onSchedulerBacklogged()
}
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted): Unit = {
- synchronized {
- val stageId = stageCompleted.stageInfo.stageId
+ val stageId = stageCompleted.stageInfo.stageId
+ allocationManager.synchronized {
stageIdToNumTasks -= stageId
stageIdToTaskIndices -= stageId
@@ -380,39 +437,49 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
}
}
- override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = synchronized {
+ override def onTaskStart(taskStart: SparkListenerTaskStart): Unit = {
val stageId = taskStart.stageId
val taskId = taskStart.taskInfo.taskId
val taskIndex = taskStart.taskInfo.index
val executorId = taskStart.taskInfo.executorId
- // If this is the last pending task, mark the scheduler queue as empty
- stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
- val numTasksScheduled = stageIdToTaskIndices(stageId).size
- val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
- if (numTasksScheduled == numTasksTotal) {
- // No more pending tasks for this stage
- stageIdToNumTasks -= stageId
- if (stageIdToNumTasks.isEmpty) {
- allocationManager.onSchedulerQueueEmpty()
+ allocationManager.synchronized {
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
+
+ // If this is the last pending task, mark the scheduler queue as empty
+ stageIdToTaskIndices.getOrElseUpdate(stageId, new mutable.HashSet[Int]) += taskIndex
+ val numTasksScheduled = stageIdToTaskIndices(stageId).size
+ val numTasksTotal = stageIdToNumTasks.getOrElse(stageId, -1)
+ if (numTasksScheduled == numTasksTotal) {
+ // No more pending tasks for this stage
+ stageIdToNumTasks -= stageId
+ if (stageIdToNumTasks.isEmpty) {
+ allocationManager.onSchedulerQueueEmpty()
+ }
}
- }
- // Mark the executor on which this task is scheduled as busy
- executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
- allocationManager.onExecutorBusy(executorId)
+ // Mark the executor on which this task is scheduled as busy
+ executorIdToTaskIds.getOrElseUpdate(executorId, new mutable.HashSet[Long]) += taskId
+ allocationManager.onExecutorBusy(executorId)
+ }
}
- override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = synchronized {
+ override def onTaskEnd(taskEnd: SparkListenerTaskEnd): Unit = {
val executorId = taskEnd.taskInfo.executorId
val taskId = taskEnd.taskInfo.taskId
-
-    // If the executor is no longer running any scheduled tasks, mark it as idle
- if (executorIdToTaskIds.contains(executorId)) {
- executorIdToTaskIds(executorId) -= taskId
- if (executorIdToTaskIds(executorId).isEmpty) {
- executorIdToTaskIds -= executorId
- allocationManager.onExecutorIdle(executorId)
+ allocationManager.synchronized {
+      // If the executor is no longer running any scheduled tasks, mark it as idle
+ if (executorIdToTaskIds.contains(executorId)) {
+ executorIdToTaskIds(executorId) -= taskId
+ if (executorIdToTaskIds(executorId).isEmpty) {
+ executorIdToTaskIds -= executorId
+ allocationManager.onExecutorIdle(executorId)
+ }
}
}
}
@@ -420,7 +487,12 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
override def onBlockManagerAdded(blockManagerAdded: SparkListenerBlockManagerAdded): Unit = {
val executorId = blockManagerAdded.blockManagerId.executorId
if (executorId != SparkContext.DRIVER_IDENTIFIER) {
- allocationManager.onExecutorAdded(executorId)
+ // This guards against the race condition in which the `SparkListenerTaskStart`
+ // event is posted before the `SparkListenerBlockManagerAdded` event, which is
+ // possible because these events are posted in different threads. (see SPARK-4951)
+ if (!allocationManager.executorIds.contains(executorId)) {
+ allocationManager.onExecutorAdded(executorId)
+ }
}
}
@@ -428,6 +500,27 @@ private[spark] class ExecutorAllocationManager(sc: SparkContext) extends Logging
blockManagerRemoved: SparkListenerBlockManagerRemoved): Unit = {
allocationManager.onExecutorRemoved(blockManagerRemoved.blockManagerId.executorId)
}
+
+ /**
+ * An estimate of the total number of pending tasks remaining for currently running stages. Does
+ * not account for tasks which may have failed and been resubmitted.
+ *
+ * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
+ */
+ def totalPendingTasks(): Int = {
+ stageIdToNumTasks.map { case (stageId, numTasks) =>
+ numTasks - stageIdToTaskIndices.get(stageId).map(_.size).getOrElse(0)
+ }.sum
+ }
+
+ /**
+ * Return true if an executor is not currently running a task, and false otherwise.
+ *
+ * Note: This is not thread-safe without the caller owning the `allocationManager` lock.
+ */
+ def isExecutorIdle(executorId: String): Boolean = {
+ !executorIdToTaskIds.contains(executorId)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/HttpFileServer.scala b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
index edc3889c9ae51..3f33332a81eaf 100644
--- a/core/src/main/scala/org/apache/spark/HttpFileServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpFileServer.scala
@@ -24,6 +24,7 @@ import com.google.common.io.Files
import org.apache.spark.util.Utils
private[spark] class HttpFileServer(
+ conf: SparkConf,
securityManager: SecurityManager,
requestedPort: Int = 0)
extends Logging {
@@ -35,13 +36,13 @@ private[spark] class HttpFileServer(
var serverUri : String = null
def initialize() {
- baseDir = Utils.createTempDir()
+ baseDir = Utils.createTempDir(Utils.getLocalDir(conf), "httpd")
fileDir = new File(baseDir, "files")
jarDir = new File(baseDir, "jars")
fileDir.mkdir()
jarDir.mkdir()
logInfo("HTTP File server directory is " + baseDir)
- httpServer = new HttpServer(baseDir, securityManager, requestedPort, "HTTP file server")
+ httpServer = new HttpServer(conf, baseDir, securityManager, requestedPort, "HTTP file server")
httpServer.start()
serverUri = httpServer.uri
logDebug("HTTP file server started at: " + serverUri)
diff --git a/core/src/main/scala/org/apache/spark/HttpServer.scala b/core/src/main/scala/org/apache/spark/HttpServer.scala
index 912558d0cab7d..fa22787ce7ea3 100644
--- a/core/src/main/scala/org/apache/spark/HttpServer.scala
+++ b/core/src/main/scala/org/apache/spark/HttpServer.scala
@@ -42,6 +42,7 @@ private[spark] class ServerStateException(message: String) extends Exception(mes
* around a Jetty server.
*/
private[spark] class HttpServer(
+ conf: SparkConf,
resourceBase: File,
securityManager: SecurityManager,
requestedPort: Int = 0,
@@ -57,7 +58,7 @@ private[spark] class HttpServer(
} else {
logInfo("Starting HTTP Server")
val (actualServer, actualPort) =
- Utils.startServiceOnPort[Server](requestedPort, doStart, serverName)
+ Utils.startServiceOnPort[Server](requestedPort, doStart, conf, serverName)
server = actualServer
port = actualPort
}
diff --git a/core/src/main/scala/org/apache/spark/Logging.scala b/core/src/main/scala/org/apache/spark/Logging.scala
index d4f2624061e35..419d093d55643 100644
--- a/core/src/main/scala/org/apache/spark/Logging.scala
+++ b/core/src/main/scala/org/apache/spark/Logging.scala
@@ -118,15 +118,17 @@ trait Logging {
// org.slf4j.impl.Log4jLoggerFactory, from the log4j 2.0 binding, currently
// org.apache.logging.slf4j.Log4jLoggerFactory
val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory".equals(binderClass)
- val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
- if (!log4j12Initialized && usingLog4j12) {
- val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
- Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
- case Some(url) =>
- PropertyConfigurator.configure(url)
- System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
- case None =>
- System.err.println(s"Spark was unable to load $defaultLogProps")
+ if (usingLog4j12) {
+ val log4j12Initialized = LogManager.getRootLogger.getAllAppenders.hasMoreElements
+ if (!log4j12Initialized) {
+ val defaultLogProps = "org/apache/spark/log4j-defaults.properties"
+ Option(Utils.getSparkClassLoader.getResource(defaultLogProps)) match {
+ case Some(url) =>
+ PropertyConfigurator.configure(url)
+ System.err.println(s"Using Spark's default log4j profile: $defaultLogProps")
+ case None =>
+ System.err.println(s"Spark was unable to load $defaultLogProps")
+ }
}
}
Logging.initialized = true
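
[Editor's note] A hedged sketch of the reordered check: consult SLF4J's binding first, and only touch log4j 1.2 classes when that binding really is log4j 1.2. With the old ordering, getAllAppenders was evaluated even when a different backend (e.g. logback or the log4j 2 bridge) was in use. Assumes slf4j-log4j12 and log4j are on the classpath, as they are in Spark core.

    object Log4jBindingSketch {
      def shouldLoadSparkDefaults(): Boolean = {
        // Ask SLF4J which backend it is bound to before touching any log4j state.
        val binderClass = org.slf4j.impl.StaticLoggerBinder.getSingleton.getLoggerFactoryClassStr
        val usingLog4j12 = "org.slf4j.impl.Log4jLoggerFactory" == binderClass
        usingLog4j12 && {
          // Safe to inspect log4j 1.2 state only inside the usingLog4j12 branch.
          !org.apache.log4j.LogManager.getRootLogger.getAllAppenders.hasMoreElements
        }
      }
    }
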
diff --git a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
index 7d96962c4acd7..a074ab8ece1b7 100644
--- a/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
+++ b/core/src/main/scala/org/apache/spark/MapOutputTracker.scala
@@ -72,7 +72,7 @@ private[spark] class MapOutputTrackerMasterActor(tracker: MapOutputTrackerMaster
/**
* Class that keeps track of the location of the map output of
* a stage. This is abstract because different versions of MapOutputTracker
- * (driver and worker) use different HashMap to store its metadata.
+ * (driver and executor) use different HashMap to store its metadata.
*/
private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging {
private val timeout = AkkaUtils.askTimeout(conf)
@@ -81,11 +81,11 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
var trackerActor: ActorRef = _
/**
- * This HashMap has different behavior for the master and the workers.
+ * This HashMap has different behavior for the driver and the executors.
*
- * On the master, it serves as the source of map outputs recorded from ShuffleMapTasks.
- * On the workers, it simply serves as a cache, in which a miss triggers a fetch from the
- * master's corresponding HashMap.
+ * On the driver, it serves as the source of map outputs recorded from ShuffleMapTasks.
+ * On the executors, it simply serves as a cache, in which a miss triggers a fetch from the
+ * driver's corresponding HashMap.
*
* Note: because mapStatuses is accessed concurrently, subclasses should make sure it's a
* thread-safe map.
@@ -99,7 +99,7 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
protected var epoch: Long = 0
protected val epochLock = new AnyRef
- /** Remembers which map output locations are currently being fetched on a worker. */
+ /** Remembers which map output locations are currently being fetched on an executor. */
private val fetching = new HashSet[Int]
/**
@@ -136,14 +136,12 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
logInfo("Don't have map outputs for shuffle " + shuffleId + ", fetching them")
var fetchedStatuses: Array[MapStatus] = null
fetching.synchronized {
- if (fetching.contains(shuffleId)) {
- // Someone else is fetching it; wait for them to be done
- while (fetching.contains(shuffleId)) {
- try {
- fetching.wait()
- } catch {
- case e: InterruptedException =>
- }
+ // Someone else is fetching it; wait for them to be done
+ while (fetching.contains(shuffleId)) {
+ try {
+ fetching.wait()
+ } catch {
+ case e: InterruptedException =>
}
}
@@ -198,8 +196,8 @@ private[spark] abstract class MapOutputTracker(conf: SparkConf) extends Logging
/**
* Called from executors to update the epoch number, potentially clearing old outputs
- * because of a fetch failure. Each worker task calls this with the latest epoch
- * number on the master at the time it was created.
+ * because of a fetch failure. Each executor task calls this with the latest epoch
+ * number on the driver at the time it was created.
*/
def updateEpoch(newEpoch: Long) {
epochLock.synchronized {
@@ -231,7 +229,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
private var cacheEpoch = epoch
/**
- * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the master,
+ * Timestamp based HashMap for storing mapStatuses and cached serialized statuses in the driver,
* so that statuses are dropped only by explicit de-registering or by TTL-based cleaning (if set).
* Other than these two scenarios, nothing should be dropped from this HashMap.
*/
@@ -341,7 +339,7 @@ private[spark] class MapOutputTrackerMaster(conf: SparkConf)
}
/**
- * MapOutputTracker for the workers, which fetches map output information from the driver's
+ * MapOutputTracker for the executors, which fetches map output information from the driver's
* MapOutputTrackerMaster.
*/
private[spark] class MapOutputTrackerWorker(conf: SparkConf) extends MapOutputTracker(conf) {
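
[Editor's note] The simplification in the fetching block above rests on a standard wait/notify idiom: always loop on the condition, so the old outer `if (fetching.contains(shuffleId))` was redundant (if nobody is fetching, the while body is simply skipped). A self-contained sketch of that pattern, with generic names standing in for the tracker's internals:

    object FetchDedupSketch {
      import scala.collection.mutable

      private val fetching = new mutable.HashSet[Int]
      private val cache = new mutable.HashMap[Int, String]

      def getOrFetch(shuffleId: Int)(remoteFetch: Int => String): String = {
        val cached = fetching.synchronized {
          // Loop on the condition; waking threads re-check it before proceeding.
          while (fetching.contains(shuffleId)) {
            try fetching.wait() catch { case _: InterruptedException => }
          }
          cache.get(shuffleId) match {
            case some @ Some(_) => some               // another thread fetched it while we waited
            case None => fetching += shuffleId; None  // we claim the fetch ourselves
          }
        }
        cached.getOrElse {
          try {
            val value = remoteFetch(shuffleId)
            fetching.synchronized { cache(shuffleId) = value }
            value
          } finally {
            fetching.synchronized {
              fetching -= shuffleId
              fetching.notifyAll()                    // wake threads waiting on this shuffle
            }
          }
        }
      }
    }
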
diff --git a/core/src/main/scala/org/apache/spark/Partition.scala b/core/src/main/scala/org/apache/spark/Partition.scala
index 27892dbd2a0bc..dd3f28e4197e3 100644
--- a/core/src/main/scala/org/apache/spark/Partition.scala
+++ b/core/src/main/scala/org/apache/spark/Partition.scala
@@ -18,11 +18,11 @@
package org.apache.spark
/**
- * A partition of an RDD.
+ * An identifier for a partition in an RDD.
*/
trait Partition extends Serializable {
/**
- * Get the split's index within its parent RDD
+ * Get the partition's index within its parent RDD
*/
def index: Int
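
[Editor's note] The doc change is subtle but worth a concrete illustration: a Partition is only a serializable identifier, not a container of data; the rows themselves come from the owning RDD's compute(split, context). A minimal, purely hypothetical implementation:

    import org.apache.spark.Partition

    // Illustrative only: a partition identified by its index plus the key range it covers.
    class RangePartition(override val index: Int, val startKey: Long, val endKey: Long)
      extends Partition
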
diff --git a/core/src/main/scala/org/apache/spark/SecurityManager.scala b/core/src/main/scala/org/apache/spark/SecurityManager.scala
index 0e0f1a7b2377e..49dae5231a92c 100644
--- a/core/src/main/scala/org/apache/spark/SecurityManager.scala
+++ b/core/src/main/scala/org/apache/spark/SecurityManager.scala
@@ -22,6 +22,7 @@ import java.net.{Authenticator, PasswordAuthentication}
import org.apache.hadoop.io.Text
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.network.sasl.SecretKeyHolder
/**
* Spark class responsible for security.
@@ -84,7 +85,7 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Authenticator installed in the SecurityManager to how it does the authentication
* and in this case gets the user name and password from the request.
*
- * - ConnectionManager -> The Spark ConnectionManager uses java nio to asynchronously
+ * - BlockTransferService -> The Spark BlockTransferService uses Java NIO to asynchronously
* exchange messages. For this we use the Java SASL
* (Simple Authentication and Security Layer) API and again use DIGEST-MD5
* as the authentication mechanism. This means the shared secret is not passed
@@ -92,31 +93,35 @@ import org.apache.spark.deploy.SparkHadoopUtil
* Note that SASL is pluggable as to what mechanism it uses. We currently use
* DIGEST-MD5 but this could be changed to use Kerberos or other in the future.
* Spark currently supports "auth" for the quality of protection, which means
- * the connection is not supporting integrity or privacy protection (encryption)
+ * the connection does not support integrity or privacy protection (encryption)
* after authentication. SASL also supports "auth-int" and "auth-conf" which
- * SPARK could be support in the future to allow the user to specify the quality
+ * SPARK could support in the future to allow the user to specify the quality
* of protection they want. If we support those, the messages will also have to
* be wrapped and unwrapped via the SaslServer/SaslClient.wrap/unwrap API's.
*
- * Since the connectionManager does asynchronous messages passing, the SASL
+ * Since the NioBlockTransferService does asynchronous message passing, the SASL
* authentication is a bit more complex. A ConnectionManager can be both a client
- * and a Server, so for a particular connection is has to determine what to do.
+ * and a Server, so for a particular connection it has to determine what to do.
* A ConnectionId was added to be able to track connections and is used to
* match up incoming messages with connections waiting for authentication.
- * The ConnectionManager tracks all the sendingConnections using the ConnectionId
- * and waits for the response from the server and does the handshake before sending
+ * The ConnectionManager tracks all the sendingConnections using the ConnectionId,
+ * waits for the response from the server, and does the handshake before sending
* the real message.
*
+ * The NettyBlockTransferService ensures that SASL authentication is performed
+ * synchronously prior to any other communication on a connection. This is done in
+ * SaslClientBootstrap on the client side and SaslRpcHandler on the server side.
+ *
* - HTTP for the Spark UI -> the UI was changed to use servlets so that javax servlet filters
* can be used. Yarn requires a specific AmIpFilter be installed for security to work
- * properly. For non-Yarn deployments, users can write a filter to go through a
- * companies normal login service. If an authentication filter is in place then the
+ * properly. For non-Yarn deployments, users can write a filter to go through their
+ * organization's normal login service. If an authentication filter is in place then the
* SparkUI can be configured to check the logged in user against the list of users who
* have view acls to see if that user is authorized.
* The filters can also be used for many different purposes. For instance filters
* could be used for logging, encryption, or compression.
*
- * The exact mechanisms used to generate/distributed the shared secret is deployment specific.
+ * The exact mechanisms used to generate/distribute the shared secret are deployment-specific.
*
* For Yarn deployments, the secret is automatically generated using the Akka remote
* Crypt.generateSecureCookie() API. The secret is placed in the Hadoop UGI which gets passed
@@ -133,13 +138,13 @@ import org.apache.spark.deploy.SparkHadoopUtil
* All the nodes (Master and Workers) and the applications need to have the same shared secret.
 * This again is not ideal as one user could potentially affect another user's application.
* This should be enhanced in the future to provide better protection.
- * If the UI needs to be secured the user needs to install a javax servlet filter to do the
+ * If the UI needs to be secure, the user needs to install a javax servlet filter to do the
* authentication. Spark will then use that user to compare against the view acls to do
 * authorization. If no filter is in place, the user is generally null and no authorization
* can take place.
*/
-private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
+private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging with SecretKeyHolder {
// key used to store the spark secret in the Hadoop UGI
private val sparkSecretLookupKey = "sparkCookie"
@@ -337,4 +342,8 @@ private[spark] class SecurityManager(sparkConf: SparkConf) extends Logging {
* @return the secret key as a String if authentication is enabled, otherwise returns null
*/
def getSecretKey(): String = secretKey
+
+ // Default SecurityManager only has a single secret key, so ignore appId.
+ override def getSaslUser(appId: String): String = getSaslUser()
+ override def getSecretKey(appId: String): String = getSecretKey()
}
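
[Editor's note] A short sketch of the contract being picked up here. SecretKeyHolder (from the network/sasl module) lets the SASL layer look credentials up per application; the default SecurityManager has a single app-wide secret, so both overrides just ignore appId. The trait below mirrors that interface for illustration only and is not the real class:

    // Mirrors org.apache.spark.network.sasl.SecretKeyHolder for illustration.
    trait SecretKeyHolderSketch {
      def getSaslUser(appId: String): String
      def getSecretKey(appId: String): String
    }

    // One secret for every application, which is what the default SecurityManager provides.
    class SingleSecretHolder(saslUser: String, secret: String) extends SecretKeyHolderSketch {
      override def getSaslUser(appId: String): String = saslUser
      override def getSecretKey(appId: String): String = secret
    }
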
diff --git a/core/src/main/scala/org/apache/spark/SparkConf.scala b/core/src/main/scala/org/apache/spark/SparkConf.scala
index ad0a9017afead..3337974978ca4 100644
--- a/core/src/main/scala/org/apache/spark/SparkConf.scala
+++ b/core/src/main/scala/org/apache/spark/SparkConf.scala
@@ -17,8 +17,11 @@
package org.apache.spark
+import java.util.concurrent.ConcurrentHashMap
+
import scala.collection.JavaConverters._
-import scala.collection.mutable.{HashMap, LinkedHashSet}
+import scala.collection.mutable.LinkedHashSet
+
import org.apache.spark.serializer.KryoSerializer
/**
@@ -46,12 +49,12 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Create a SparkConf that loads defaults from system properties and the classpath */
def this() = this(true)
- private[spark] val settings = new HashMap[String, String]()
+ private val settings = new ConcurrentHashMap[String, String]()
if (loadDefaults) {
// Load any spark.* system properties
for ((k, v) <- System.getProperties.asScala if k.startsWith("spark.")) {
- settings(k) = v
+ set(k, v)
}
}
@@ -63,7 +66,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
if (value == null) {
throw new NullPointerException("null value")
}
- settings(key) = value
+ settings.put(key, value)
this
}
@@ -129,15 +132,13 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Set multiple parameters together */
def setAll(settings: Traversable[(String, String)]) = {
- this.settings ++= settings
+ this.settings.putAll(settings.toMap.asJava)
this
}
/** Set a parameter if it isn't already configured */
def setIfMissing(key: String, value: String): SparkConf = {
- if (!settings.contains(key)) {
- settings(key) = value
- }
+ settings.putIfAbsent(key, value)
this
}
@@ -163,21 +164,23 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Get a parameter; throws a NoSuchElementException if it's not set */
def get(key: String): String = {
- settings.getOrElse(key, throw new NoSuchElementException(key))
+ getOption(key).getOrElse(throw new NoSuchElementException(key))
}
/** Get a parameter, falling back to a default if not set */
def get(key: String, defaultValue: String): String = {
- settings.getOrElse(key, defaultValue)
+ getOption(key).getOrElse(defaultValue)
}
/** Get a parameter as an Option */
def getOption(key: String): Option[String] = {
- settings.get(key)
+ Option(settings.get(key))
}
/** Get all parameters as a list of pairs */
- def getAll: Array[(String, String)] = settings.clone().toArray
+ def getAll: Array[(String, String)] = {
+ settings.entrySet().asScala.map(x => (x.getKey, x.getValue)).toArray
+ }
/** Get a parameter as an integer, falling back to a default if not set */
def getInt(key: String, defaultValue: Int): Int = {
@@ -217,12 +220,18 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
*/
getAll.filter { case (k, _) => isAkkaConf(k) }
+ /**
+ * Returns the Spark application id, valid in the Driver after TaskScheduler registration and
+ * from the start in the Executor.
+ */
+ def getAppId: String = get("spark.app.id")
+
/** Does the configuration contain a given parameter? */
- def contains(key: String): Boolean = settings.contains(key)
+ def contains(key: String): Boolean = settings.containsKey(key)
/** Copy this object */
override def clone: SparkConf = {
- new SparkConf(false).setAll(settings)
+ new SparkConf(false).setAll(getAll)
}
/**
@@ -234,7 +243,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
/** Checks for illegal or deprecated config settings. Throws an exception for the former. Not
* idempotent - may mutate this conf object to convert deprecated settings to supported ones. */
private[spark] def validateSettings() {
- if (settings.contains("spark.local.dir")) {
+ if (contains("spark.local.dir")) {
val msg = "In Spark 1.0 and later spark.local.dir will be overridden by the value set by " +
"the cluster manager (via SPARK_LOCAL_DIRS in mesos/standalone and LOCAL_DIRS in YARN)."
logWarning(msg)
@@ -259,7 +268,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
}
// Validate spark.executor.extraJavaOptions
- settings.get(executorOptsKey).map { javaOpts =>
+ getOption(executorOptsKey).map { javaOpts =>
if (javaOpts.contains("-Dspark")) {
val msg = s"$executorOptsKey is not allowed to set Spark options (was '$javaOpts'). " +
"Set them directly on a SparkConf or in a properties file when using ./bin/spark-submit."
@@ -339,7 +348,7 @@ class SparkConf(loadDefaults: Boolean) extends Cloneable with Logging {
* configuration out for debugging.
*/
def toDebugString: String = {
- settings.toArray.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
+ getAll.sorted.map{case (k, v) => k + "=" + v}.mkString("\n")
}
}
@@ -364,7 +373,9 @@ private[spark] object SparkConf {
}
/**
- * Return whether the given config is a Spark port config.
+ * Return true if the given config matches either `spark.*.port` or `spark.port.*`.
*/
- def isSparkPortConf(name: String): Boolean = name.startsWith("spark.") && name.endsWith(".port")
+ def isSparkPortConf(name: String): Boolean = {
+ (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")
+ }
}
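
[Editor's note] Two small things in the SparkConf hunks deserve a concrete example: lookups now go through a ConcurrentHashMap (so a missing key becomes None via Option(null), and setIfMissing is an atomic putIfAbsent), and the port predicate now matches both `spark.*.port` and `spark.port.*`. A self-contained sketch, not the real SparkConf:

    object ConfSketch {
      import java.util.concurrent.ConcurrentHashMap

      private val settings = new ConcurrentHashMap[String, String]()

      def getOption(key: String): Option[String] = Option(settings.get(key))
      def setIfMissing(key: String, value: String): Unit = settings.putIfAbsent(key, value)

      // The widened predicate from the patch.
      def isSparkPortConf(name: String): Boolean =
        (name.startsWith("spark.") && name.endsWith(".port")) || name.startsWith("spark.port.")

      def main(args: Array[String]): Unit = {
        setIfMissing("spark.app.name", "demo")
        setIfMissing("spark.app.name", "ignored")          // first writer wins
        println(getOption("spark.app.name"))               // Some(demo)
        println(isSparkPortConf("spark.driver.port"))      // true  (spark.*.port)
        println(isSparkPortConf("spark.port.maxRetries"))  // true  (spark.port.*)
        println(isSparkPortConf("spark.executor.memory"))  // false
      }
    }
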
diff --git a/core/src/main/scala/org/apache/spark/SparkContext.scala b/core/src/main/scala/org/apache/spark/SparkContext.scala
index 8b4db783979ec..b50a54126ea3a 100644
--- a/core/src/main/scala/org/apache/spark/SparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/SparkContext.scala
@@ -21,11 +21,11 @@ import scala.language.implicitConversions
import java.io._
import java.net.URI
-import java.util.Arrays
+import java.util.{Arrays, Properties, UUID}
import java.util.concurrent.atomic.AtomicInteger
-import java.util.{Properties, UUID}
import java.util.UUID.randomUUID
import scala.collection.{Map, Set}
+import scala.collection.JavaConversions._
import scala.collection.generic.Growable
import scala.collection.mutable.HashMap
import scala.reflect.{ClassTag, classTag}
@@ -41,6 +41,7 @@ import akka.actor.Props
import org.apache.spark.annotation.{DeveloperApi, Experimental}
import org.apache.spark.broadcast.Broadcast
import org.apache.spark.deploy.{LocalSparkCluster, SparkHadoopUtil}
+import org.apache.spark.executor.TriggerThreadDump
import org.apache.spark.input.{StreamInputFormat, PortableDataStream, WholeTextFileInputFormat, FixedLengthBinaryInputFormat}
import org.apache.spark.partial.{ApproximateEvaluator, PartialResult}
import org.apache.spark.rdd._
@@ -49,25 +50,49 @@ import org.apache.spark.scheduler.cluster.{CoarseGrainedSchedulerBackend, SparkD
import org.apache.spark.scheduler.cluster.mesos.{CoarseMesosSchedulerBackend, MesosSchedulerBackend}
import org.apache.spark.scheduler.local.LocalBackend
import org.apache.spark.storage._
-import org.apache.spark.ui.SparkUI
+import org.apache.spark.ui.{SparkUI, ConsoleProgressBar}
import org.apache.spark.ui.jobs.JobProgressListener
-import org.apache.spark.util.{CallSite, ClosureCleaner, MetadataCleaner, MetadataCleanerType, TimeStampedWeakValueHashMap, Utils}
+import org.apache.spark.util._
/**
* Main entry point for Spark functionality. A SparkContext represents the connection to a Spark
* cluster, and can be used to create RDDs, accumulators and broadcast variables on that cluster.
*
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
+ *
* @param config a Spark Config object describing the application configuration. Any settings in
* this config overrides the default configs as well as system properties.
*/
+class SparkContext(config: SparkConf) extends Logging with ExecutorAllocationClient {
+
+ // The call site where this SparkContext was constructed.
+ private val creationSite: CallSite = Utils.getCallSite()
-class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
+ // If true, log warnings instead of throwing exceptions when multiple SparkContexts are active
+ private val allowMultipleContexts: Boolean =
+ config.getBoolean("spark.driver.allowMultipleContexts", false)
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having started construction.
+ // NOTE: this must be placed at the beginning of the SparkContext constructor.
+ SparkContext.markPartiallyConstructed(this, allowMultipleContexts)
// This is used only by YARN for now, but should be relevant to other cluster types (Mesos,
// etc) too. This is typically generated from InputFormatInfo.computePreferredLocations. It
// contains a map from hostname to a list of input format splits on the host.
private[spark] var preferredNodeLocationData: Map[String, Set[SplitInfo]] = Map()
+ val startTime = System.currentTimeMillis()
+
+ @volatile private var stopped: Boolean = false
+
+ private def assertNotStopped(): Unit = {
+ if (stopped) {
+ throw new IllegalStateException("Cannot call methods on a stopped SparkContext")
+ }
+ }
+
/**
* Create a SparkContext that loads settings from system properties (for instance, when
* launching with ./bin/spark-submit).
@@ -228,6 +253,15 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
private[spark] val jobProgressListener = new JobProgressListener(conf)
listenerBus.addListener(jobProgressListener)
+ val statusTracker = new SparkStatusTracker(this)
+
+ private[spark] val progressBar: Option[ConsoleProgressBar] =
+ if (conf.getBoolean("spark.ui.showConsoleProgress", true) && !log.isInfoEnabled) {
+ Some(new ConsoleProgressBar(this))
+ } else {
+ None
+ }
+
// Initialize the Spark UI
private[spark] val ui: Option[SparkUI] =
if (conf.getBoolean("spark.ui.enabled", true)) {
@@ -245,8 +279,6 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
/** A default Hadoop Configuration for the Hadoop code (e.g. file systems) that we reuse. */
val hadoopConfiguration = SparkHadoopUtil.get.newConfiguration(conf)
- val startTime = System.currentTimeMillis()
-
// Add each JAR given through the constructor
if (jars != null) {
jars.foreach(addJar)
@@ -302,8 +334,13 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
try {
dagScheduler = new DAGScheduler(this)
} catch {
- case e: Exception => throw
- new SparkException("DAGScheduler cannot be initialized due to %s".format(e.getMessage))
+ case e: Exception => {
+ try {
+ stop()
+ } finally {
+ throw new SparkException("Error while constructing DAGScheduler", e)
+ }
+ }
}
// start TaskScheduler after taskScheduler sets DAGScheduler reference in DAGScheduler's
@@ -313,11 +350,15 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
val applicationId: String = taskScheduler.applicationId()
conf.set("spark.app.id", applicationId)
+ env.blockManager.initialize(applicationId)
+
val metricsSystem = env.metricsSystem
// The metrics system for Driver need to be set spark.app.id to app ID.
// So it should start after we get app ID from the task scheduler and set spark.app.id.
metricsSystem.start()
+ // Attach the driver metrics servlet handler to the web ui after the metrics system is started.
+ metricsSystem.getServletHandlers.foreach(handler => ui.foreach(_.attachHandler(handler)))
// Optionally log Spark events
private[spark] val eventLogger: Option[EventLoggingListener] = {
@@ -333,7 +374,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
// Optionally scale number of executors dynamically based on workload. Exposed for testing.
private[spark] val executorAllocationManager: Option[ExecutorAllocationManager] =
if (conf.getBoolean("spark.dynamicAllocation.enabled", false)) {
- Some(new ExecutorAllocationManager(this))
+ Some(new ExecutorAllocationManager(this, listenerBus, conf))
} else {
None
}
@@ -361,6 +402,29 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
override protected def childValue(parent: Properties): Properties = new Properties(parent)
}
+ /**
+ * Called by the web UI to obtain executor thread dumps. This method may be expensive.
+ * Logs an error and returns None if we failed to obtain a thread dump, which could occur due
+ * to an executor being dead or unresponsive or due to network issues while sending the thread
+ * dump message back to the driver.
+ */
+ private[spark] def getExecutorThreadDump(executorId: String): Option[Array[ThreadStackTrace]] = {
+ try {
+ if (executorId == SparkContext.DRIVER_IDENTIFIER) {
+ Some(Utils.getThreadDump())
+ } else {
+ val (host, port) = env.blockManager.master.getActorSystemHostPortForExecutor(executorId).get
+ val actorRef = AkkaUtils.makeExecutorRef("ExecutorActor", conf, host, port, env.actorSystem)
+ Some(AkkaUtils.askWithReply[Array[ThreadStackTrace]](TriggerThreadDump, actorRef,
+ AkkaUtils.numRetries(conf), AkkaUtils.retryWaitMs(conf), AkkaUtils.askTimeout(conf)))
+ }
+ } catch {
+ case e: Exception =>
+ logError(s"Exception getting thread dump from executor $executorId", e)
+ None
+ }
+ }
+
private[spark] def getLocalProperties: Properties = localProperties.get()
private[spark] def setLocalProperties(props: Properties) {
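
[Editor's note] The driver branch of getExecutorThreadDump calls Utils.getThreadDump(), which is essentially a ThreadMXBean dump of the local JVM; the executor branch instead ships a TriggerThreadDump message over Akka and waits for the reply. A hedged sketch of the local case only (the remote plumbing is omitted):

    object ThreadDumpSketch {
      import java.lang.management.ManagementFactory

      // Roughly what the driver-local branch boils down to: dump all live threads.
      def localThreadDump(): Seq[String] = {
        val threadMXBean = ManagementFactory.getThreadMXBean
        threadMXBean.dumpAllThreads(true, true).toSeq.map { info =>
          s"${info.getThreadName} (${info.getThreadState}): " +
            info.getStackTrace.take(3).mkString(" <- ")
        }
      }

      def main(args: Array[String]): Unit = localThreadDump().foreach(println)
    }
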
@@ -458,12 +522,12 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
/** Distribute a local Scala collection to form an RDD.
*
- * @note Parallelize acts lazily. If `seq` is a mutable collection and is
- * altered after the call to parallelize and before the first action on the
- * RDD, the resultant RDD will reflect the modified collection. Pass a copy of
- * the argument to avoid this.
+ * @note Parallelize acts lazily. If `seq` is a mutable collection and is altered after the call
+ * to parallelize and before the first action on the RDD, the resultant RDD will reflect the
+ * modified collection. Pass a copy of the argument to avoid this.
*/
def parallelize[T: ClassTag](seq: Seq[T], numSlices: Int = defaultParallelism): RDD[T] = {
+ assertNotStopped()
new ParallelCollectionRDD[T](this, seq, numSlices, Map[Int, Seq[String]]())
}
@@ -479,6 +543,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* location preferences (hostnames of Spark nodes) for each object.
* Create a new partition for each collection item. */
def makeRDD[T: ClassTag](seq: Seq[(T, Seq[String])]): RDD[T] = {
+ assertNotStopped()
val indexToPrefs = seq.zipWithIndex.map(t => (t._2, t._1._2)).toMap
new ParallelCollectionRDD[T](this, seq.map(_._1), seq.size, indexToPrefs)
}
@@ -488,6 +553,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* Hadoop-supported file system URI, and return it as an RDD of Strings.
*/
def textFile(path: String, minPartitions: Int = defaultMinPartitions): RDD[String] = {
+ assertNotStopped()
hadoopFile(path, classOf[TextInputFormat], classOf[LongWritable], classOf[Text],
minPartitions).map(pair => pair._2.toString).setName(path)
}
@@ -521,6 +587,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
*/
def wholeTextFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, String)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -535,6 +602,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
/**
+ * :: Experimental ::
+ *
* Get an RDD for a Hadoop-readable dataset as PortableDataStream for each file
* (useful for binary data)
*
@@ -564,6 +633,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
@Experimental
def binaryFiles(path: String, minPartitions: Int = defaultMinPartitions):
RDD[(String, PortableDataStream)] = {
+ assertNotStopped()
val job = new NewHadoopJob(hadoopConfiguration)
NewFileInputFormat.addInputPath(job, new Path(path))
val updateConf = job.getConfiguration
@@ -577,6 +647,8 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
}
/**
+ * :: Experimental ::
+ *
* Load data from a flat binary file, assuming the length of each record is constant.
*
* @param path Directory to the input data files
@@ -586,6 +658,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
@Experimental
def binaryRecords(path: String, recordLength: Int, conf: Configuration = hadoopConfiguration)
: RDD[Array[Byte]] = {
+ assertNotStopped()
conf.setInt(FixedLengthBinaryInputFormat.RECORD_LENGTH_PROPERTY, recordLength)
val br = newAPIHadoopFile[LongWritable, BytesWritable, FixedLengthBinaryInputFormat](path,
classOf[FixedLengthBinaryInputFormat],
@@ -619,6 +692,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// Add necessary security credentials to the JobConf before broadcasting it.
SparkHadoopUtil.get.addCredentials(conf)
new HadoopRDD(this, conf, inputFormatClass, keyClass, valueClass, minPartitions)
@@ -638,6 +712,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
valueClass: Class[V],
minPartitions: Int = defaultMinPartitions
): RDD[(K, V)] = {
+ assertNotStopped()
// A Hadoop configuration can be about 10 KB, which is pretty big, so broadcast it.
val confBroadcast = broadcast(new SerializableWritable(hadoopConfiguration))
val setInputPathsFunc = (jobConf: JobConf) => FileInputFormat.setInputPaths(jobConf, path)
@@ -717,6 +792,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
kClass: Class[K],
vClass: Class[V],
conf: Configuration = hadoopConfiguration): RDD[(K, V)] = {
+ assertNotStopped()
val job = new NewHadoopJob(conf)
NewFileInputFormat.addInputPath(job, new Path(path))
val updatedConf = job.getConfiguration
@@ -737,6 +813,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
fClass: Class[F],
kClass: Class[K],
vClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
new NewHadoopRDD(this, fClass, kClass, vClass, conf)
}
@@ -752,6 +829,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
valueClass: Class[V],
minPartitions: Int
): RDD[(K, V)] = {
+ assertNotStopped()
val inputFormatClass = classOf[SequenceFileInputFormat[K, V]]
hadoopFile(path, inputFormatClass, keyClass, valueClass, minPartitions)
}
@@ -763,9 +841,10 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* If you plan to directly cache Hadoop writable objects, you should first copy them using
* a `map` function.
* */
- def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]
- ): RDD[(K, V)] =
+ def sequenceFile[K, V](path: String, keyClass: Class[K], valueClass: Class[V]): RDD[(K, V)] = {
+ assertNotStopped()
sequenceFile(path, keyClass, valueClass, defaultMinPartitions)
+ }
/**
* Version of sequenceFile() for types implicitly convertible to Writables through a
@@ -793,6 +872,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
(implicit km: ClassTag[K], vm: ClassTag[V],
kcf: () => WritableConverter[K], vcf: () => WritableConverter[V])
: RDD[(K, V)] = {
+ assertNotStopped()
val kc = kcf()
val vc = vcf()
val format = classOf[SequenceFileInputFormat[Writable, Writable]]
@@ -814,6 +894,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
path: String,
minPartitions: Int = defaultMinPartitions
): RDD[T] = {
+ assertNotStopped()
sequenceFile(path, classOf[NullWritable], classOf[BytesWritable], minPartitions)
.flatMap(x => Utils.deserialize[Array[T]](x._2.getBytes, Utils.getContextOrSparkClassLoader))
}
@@ -889,6 +970,13 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* The variable will be sent to each cluster only once.
*/
def broadcast[T: ClassTag](value: T): Broadcast[T] = {
+ assertNotStopped()
+ if (classOf[RDD[_]].isAssignableFrom(classTag[T].runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have created RDD broadcast variables but not used them:
+ logWarning("Can not directly broadcast RDDs; instead, call collect() and "
+ + "broadcast the result (see SPARK-5063)")
+ }
val bc = env.broadcastManager.newBroadcast[T](value, isLocal)
val callSite = getCallSite
logInfo("Created broadcast " + bc.id + " from " + callSite.shortForm)
@@ -935,7 +1023,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def requestExecutors(numAdditionalExecutors: Int): Boolean = {
+ override def requestExecutors(numAdditionalExecutors: Int): Boolean = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.requestExecutors(numAdditionalExecutors)
@@ -951,7 +1039,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def killExecutors(executorIds: Seq[String]): Boolean = {
+ override def killExecutors(executorIds: Seq[String]): Boolean = {
schedulerBackend match {
case b: CoarseGrainedSchedulerBackend =>
b.killExecutors(executorIds)
@@ -967,11 +1055,80 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* This is currently only supported in Yarn mode. Return whether the request is received.
*/
@DeveloperApi
- def killExecutor(executorId: String): Boolean = killExecutors(Seq(executorId))
+ override def killExecutor(executorId: String): Boolean = super.killExecutor(executorId)
/** The version of Spark on which this application is running. */
def version = SPARK_VERSION
+ /**
+ * Return a map from each slave to both the maximum memory available for caching and the
+ * remaining memory available for caching.
+ */
+ def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
+ assertNotStopped()
+ env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
+ (blockManagerId.host + ":" + blockManagerId.port, mem)
+ }
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Return information about what RDDs are cached, whether they are in memory or on disk, how much space
+ * they take, etc.
+ */
+ @DeveloperApi
+ def getRDDStorageInfo: Array[RDDInfo] = {
+ assertNotStopped()
+ val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
+ StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
+ rddInfos.filter(_.isCached)
+ }
+
+ /**
+ * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call.
+ * Note that this does not necessarily mean the caching or computation was successful.
+ */
+ def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
+
+ /**
+ * :: DeveloperApi ::
+ * Return information about blocks stored in all of the slaves
+ */
+ @DeveloperApi
+ def getExecutorStorageStatus: Array[StorageStatus] = {
+ assertNotStopped()
+ env.blockManager.master.getStorageStatus
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Return pools for fair scheduler
+ */
+ @DeveloperApi
+ def getAllPools: Seq[Schedulable] = {
+ assertNotStopped()
+ // TODO(xiajunluan): We should take nested pools into account
+ taskScheduler.rootPool.schedulableQueue.toSeq
+ }
+
+ /**
+ * :: DeveloperApi ::
+ * Return the pool associated with the given name, if one exists
+ */
+ @DeveloperApi
+ def getPoolForName(pool: String): Option[Schedulable] = {
+ assertNotStopped()
+ Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
+ }
+
+ /**
+ * Return current scheduling mode
+ */
+ def getSchedulingMode: SchedulingMode.SchedulingMode = {
+ assertNotStopped()
+ taskScheduler.schedulingMode
+ }
+
/**
* Clear the job's list of files added by `addFile` so that they do not get downloaded to
* any new nodes.
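
[Editor's note] A small usage sketch of the monitoring helpers surfaced in this hunk, assuming a live SparkContext; all of them now refuse to run (IllegalStateException) once the context has been stopped:

    object ClusterStateSketch {
      import org.apache.spark.SparkContext

      def printClusterState(sc: SparkContext): Unit = {
        sc.getExecutorMemoryStatus.foreach { case (hostPort, (maxMem, remaining)) =>
          println(s"$hostPort: ${remaining / (1024 * 1024)} MB free of ${maxMem / (1024 * 1024)} MB")
        }
        sc.getRDDStorageInfo.foreach { info =>
          println(s"RDD ${info.id} '${info.name}': " +
            s"${info.numCachedPartitions}/${info.numPartitions} partitions cached")
        }
        println(s"Scheduling mode: ${sc.getSchedulingMode}")
      }
    }
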
@@ -1071,27 +1228,28 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
/** Shut down the SparkContext. */
def stop() {
- postApplicationEnd()
- ui.foreach(_.stop())
- // Do this only if not stopped already - best case effort.
- // prevent NPE if stopped more than once.
- val dagSchedulerCopy = dagScheduler
- dagScheduler = null
- if (dagSchedulerCopy != null) {
- env.metricsSystem.report()
- metadataCleaner.cancel()
- env.actorSystem.stop(heartbeatReceiver)
- cleaner.foreach(_.stop())
- dagSchedulerCopy.stop()
- taskScheduler = null
- // TODO: Cache.stop()?
- env.stop()
- SparkEnv.set(null)
- listenerBus.stop()
- eventLogger.foreach(_.stop())
- logInfo("Successfully stopped SparkContext")
- } else {
- logInfo("SparkContext already stopped")
+ SparkContext.SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ postApplicationEnd()
+ ui.foreach(_.stop())
+ if (!stopped) {
+ stopped = true
+ env.metricsSystem.report()
+ metadataCleaner.cancel()
+ env.actorSystem.stop(heartbeatReceiver)
+ cleaner.foreach(_.stop())
+ dagScheduler.stop()
+ dagScheduler = null
+ taskScheduler = null
+ // TODO: Cache.stop()?
+ env.stop()
+ SparkEnv.set(null)
+ listenerBus.stop()
+ eventLogger.foreach(_.stop())
+ logInfo("Successfully stopped SparkContext")
+ SparkContext.clearActiveContext()
+ } else {
+ logInfo("SparkContext already stopped")
+ }
}
}
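
[Editor's note] With the `stopped` flag recorded under the constructor lock, stop() stays idempotent and also clears the active-context bookkeeping. A small usage sketch in local mode (hypothetical job, not part of the patch):

    object StopTwiceSketch {
      import org.apache.spark.{SparkConf, SparkContext}

      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("stop-twice"))
        try {
          sc.parallelize(1 to 10).count()
        } finally {
          sc.stop()
          sc.stop()  // second call just logs "SparkContext already stopped"
        }
      }
    }
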
@@ -1154,14 +1312,15 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
partitions: Seq[Int],
allowLocal: Boolean,
resultHandler: (Int, U) => Unit) {
- if (dagScheduler == null) {
- throw new SparkException("SparkContext has been shutdown")
+ if (stopped) {
+ throw new IllegalStateException("SparkContext has been shutdown")
}
val callSite = getCallSite
val cleanedFunc = clean(func)
logInfo("Starting job: " + callSite.shortForm)
dagScheduler.runJob(rdd, cleanedFunc, partitions, callSite, allowLocal,
resultHandler, localProperties.get)
+ progressBar.foreach(_.finishAll())
rdd.doCheckpoint()
}
@@ -1241,6 +1400,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
func: (TaskContext, Iterator[T]) => U,
evaluator: ApproximateEvaluator[U, R],
timeout: Long): PartialResult[R] = {
+ assertNotStopped()
val callSite = getCallSite
logInfo("Starting job: " + callSite.shortForm)
val start = System.nanoTime
@@ -1263,6 +1423,7 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
resultHandler: (Int, U) => Unit,
resultFunc: => R): SimpleFutureAction[R] =
{
+ assertNotStopped()
val cleanF = clean(processPartition)
val callSite = getCallSite
val waiter = dagScheduler.submitJob(
@@ -1281,11 +1442,13 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
* for more information.
*/
def cancelJobGroup(groupId: String) {
+ assertNotStopped()
dagScheduler.cancelJobGroup(groupId)
}
/** Cancel all jobs that have been scheduled or are running. */
def cancelAllJobs() {
+ assertNotStopped()
dagScheduler.cancelAllJobs()
}
@@ -1332,7 +1495,10 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
def getCheckpointDir = checkpointDir
/** Default level of parallelism to use when not given by user (e.g. parallelize and makeRDD). */
- def defaultParallelism: Int = taskScheduler.defaultParallelism
+ def defaultParallelism: Int = {
+ assertNotStopped()
+ taskScheduler.defaultParallelism
+ }
/** Default min number of partitions for Hadoop RDDs when not given by user */
@deprecated("use defaultMinPartitions", "1.0.0")
@@ -1380,6 +1546,11 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
private[spark] def cleanup(cleanupTime: Long) {
persistentRdds.clearOldValues(cleanupTime)
}
+
+ // In order to prevent multiple SparkContexts from being active at the same time, mark this
+ // context as having finished construction.
+ // NOTE: this must be placed at the end of the SparkContext constructor.
+ SparkContext.setActiveContext(this, allowMultipleContexts)
}
/**
@@ -1388,6 +1559,107 @@ class SparkContext(config: SparkConf) extends SparkStatusAPI with Logging {
*/
object SparkContext extends Logging {
+ /**
+ * Lock that guards access to global variables that track SparkContext construction.
+ */
+ private val SPARK_CONTEXT_CONSTRUCTOR_LOCK = new Object()
+
+ /**
+ * The active, fully-constructed SparkContext. If no SparkContext is active, then this is `None`.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var activeContext: Option[SparkContext] = None
+
+ /**
+ * Points to a partially-constructed SparkContext if some thread is in the SparkContext
+ * constructor, or `None` if no SparkContext is being constructed.
+ *
+ * Access to this field is guarded by SPARK_CONTEXT_CONSTRUCTOR_LOCK
+ */
+ private var contextBeingConstructed: Option[SparkContext] = None
+
+ /**
+ * Called to ensure that no other SparkContext is running in this JVM.
+ *
+ * Throws an exception if a running context is detected and logs a warning if another thread is
+ * constructing a SparkContext. This warning is necessary because the current locking scheme
+ * prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private def assertNoOtherContextIsRunning(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ contextBeingConstructed.foreach { otherContext =>
+ if (otherContext ne sc) { // checks for reference equality
+ // Since otherContext might point to a partially-constructed context, guard against
+ // its creationSite field being null:
+ val otherContextCreationSite =
+ Option(otherContext.creationSite).map(_.longForm).getOrElse("unknown location")
+ val warnMsg = "Another SparkContext is being constructed (or threw an exception in its" +
+ " constructor). This may indicate an error, since only one SparkContext may be" +
+ " running in this JVM (see SPARK-2243)." +
+ s" The other SparkContext was created at:\n$otherContextCreationSite"
+ logWarning(warnMsg)
+ }
+
+ activeContext.foreach { ctx =>
+ val errMsg = "Only one SparkContext may be running in this JVM (see SPARK-2243)." +
+ " To ignore this error, set spark.driver.allowMultipleContexts = true. " +
+ s"The currently running SparkContext was created at:\n${ctx.creationSite.longForm}"
+ val exception = new SparkException(errMsg)
+ if (allowMultipleContexts) {
+ logWarning("Multiple running SparkContexts detected in the same JVM!", exception)
+ } else {
+ throw exception
+ }
+ }
+ }
+ }
+ }
+
+ /**
+ * Called at the beginning of the SparkContext constructor to ensure that no SparkContext is
+ * running. Throws an exception if a running context is detected and logs a warning if another
+ * thread is constructing a SparkContext. This warning is necessary because the current locking
+ * scheme prevents us from reliably distinguishing between cases where another context is being
+ * constructed and cases where another constructor threw an exception.
+ */
+ private[spark] def markPartiallyConstructed(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = Some(sc)
+ }
+ }
+
+ /**
+ * Called at the end of the SparkContext constructor to ensure that no other SparkContext has
+ * raced with this constructor and started.
+ */
+ private[spark] def setActiveContext(
+ sc: SparkContext,
+ allowMultipleContexts: Boolean): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ assertNoOtherContextIsRunning(sc, allowMultipleContexts)
+ contextBeingConstructed = None
+ activeContext = Some(sc)
+ }
+ }
+
+ /**
+ * Clears the active SparkContext metadata. This is called by `SparkContext#stop()`. It's
+ * also called in unit tests to prevent a flood of warnings from test suites that don't / can't
+ * properly clean up their SparkContexts.
+ */
+ private[spark] def clearActiveContext(): Unit = {
+ SPARK_CONTEXT_CONSTRUCTOR_LOCK.synchronized {
+ activeContext = None
+ }
+ }
+
private[spark] val SPARK_JOB_DESCRIPTION = "spark.job.description"
private[spark] val SPARK_JOB_GROUP_ID = "spark.jobGroup.id"
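
[Editor's note] A usage sketch of the new single-active-context enforcement, assuming local mode:

    object SingleContextSketch {
      import org.apache.spark.{SparkConf, SparkContext}

      def main(args: Array[String]): Unit = {
        val first = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("first"))

        // With the default spark.driver.allowMultipleContexts=false this would throw a
        // SparkException whose message includes the creation site of `first` (SPARK-2243):
        // val second = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("second"))

        first.stop()  // stop() calls clearActiveContext(), so a new context is allowed now
        val replacement = new SparkContext(new SparkConf().setMaster("local[2]").setAppName("second"))
        replacement.stop()
      }
    }
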
@@ -1587,6 +1859,9 @@ object SparkContext extends Logging {
def localCpuCount = Runtime.getRuntime.availableProcessors()
// local[*] estimates the number of cores on the machine; local[N] uses exactly N threads.
val threadCount = if (threads == "*") localCpuCount else threads.toInt
+ if (threadCount <= 0) {
+ throw new SparkException(s"Asked to run locally with $threadCount threads")
+ }
val scheduler = new TaskSchedulerImpl(sc, MAX_LOCAL_TASK_FAILURES, isLocal = true)
val backend = new LocalBackend(scheduler, threadCount)
scheduler.initialize(backend)
diff --git a/core/src/main/scala/org/apache/spark/SparkEnv.scala b/core/src/main/scala/org/apache/spark/SparkEnv.scala
index e2f13accdfab5..5d465c567ba81 100644
--- a/core/src/main/scala/org/apache/spark/SparkEnv.scala
+++ b/core/src/main/scala/org/apache/spark/SparkEnv.scala
@@ -156,7 +156,15 @@ object SparkEnv extends Logging {
assert(conf.contains("spark.driver.port"), "spark.driver.port is not set on the driver!")
val hostname = conf.get("spark.driver.host")
val port = conf.get("spark.driver.port").toInt
- create(conf, SparkContext.DRIVER_IDENTIFIER, hostname, port, true, isLocal, listenerBus)
+ create(
+ conf,
+ SparkContext.DRIVER_IDENTIFIER,
+ hostname,
+ port,
+ isDriver = true,
+ isLocal = isLocal,
+ listenerBus = listenerBus
+ )
}
/**
@@ -168,9 +176,19 @@ object SparkEnv extends Logging {
executorId: String,
hostname: String,
port: Int,
+ numCores: Int,
isLocal: Boolean,
actorSystem: ActorSystem = null): SparkEnv = {
- create(conf, executorId, hostname, port, false, isLocal, defaultActorSystem = actorSystem)
+ create(
+ conf,
+ executorId,
+ hostname,
+ port,
+ isDriver = false,
+ isLocal = isLocal,
+ defaultActorSystem = actorSystem,
+ numUsableCores = numCores
+ )
}
/**
@@ -184,7 +202,8 @@ object SparkEnv extends Logging {
isDriver: Boolean,
isLocal: Boolean,
listenerBus: LiveListenerBus = null,
- defaultActorSystem: ActorSystem = null): SparkEnv = {
+ defaultActorSystem: ActorSystem = null,
+ numUsableCores: Int = 0): SparkEnv = {
// Listener bus is only used on the driver
if (isDriver) {
@@ -276,7 +295,7 @@ object SparkEnv extends Logging {
val blockTransferService =
conf.get("spark.shuffle.blockTransferService", "netty").toLowerCase match {
case "netty" =>
- new NettyBlockTransferService(conf)
+ new NettyBlockTransferService(conf, securityManager, numUsableCores)
case "nio" =>
new NioBlockTransferService(conf, securityManager)
}
@@ -285,8 +304,10 @@ object SparkEnv extends Logging {
"BlockManagerMaster",
new BlockManagerMasterActor(isLocal, conf, listenerBus)), conf, isDriver)
+ // NB: blockManager is not valid until initialize() is called later.
val blockManager = new BlockManager(executorId, actorSystem, blockManagerMaster,
- serializer, conf, mapOutputTracker, shuffleManager, blockTransferService)
+ serializer, conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager,
+ numUsableCores)
val broadcastManager = new BroadcastManager(isDriver, conf, securityManager)
@@ -295,7 +316,7 @@ object SparkEnv extends Logging {
val httpFileServer =
if (isDriver) {
val fileServerPort = conf.getInt("spark.fileserver.port", 0)
- val server = new HttpFileServer(securityManager, fileServerPort)
+ val server = new HttpFileServer(conf, securityManager, fileServerPort)
server.initialize()
conf.set("spark.fileserver.uri", server.serverUri)
server
@@ -309,6 +330,10 @@ object SparkEnv extends Logging {
// Then we can start the metrics system.
MetricsSystem.createMetricsSystem("driver", conf, securityManager)
} else {
+ // We need to set the executor ID before the MetricsSystem is created because sources and
+ // sinks specified in the metrics configuration file will want to incorporate this executor's
+ // ID into the metrics they report.
+ conf.set("spark.executor.id", executorId)
val ms = MetricsSystem.createMetricsSystem("executor", conf, securityManager)
ms.start()
ms
@@ -318,7 +343,7 @@ object SparkEnv extends Logging {
// this is a temporary directory; in distributed mode, this is the executor's current working
// directory.
val sparkFilesDir: String = if (isDriver) {
- Utils.createTempDir().getAbsolutePath
+ Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
} else {
"."
}
@@ -378,7 +403,7 @@ object SparkEnv extends Logging {
val sparkProperties = (conf.getAll ++ schedulerMode).sorted
// System properties that are not java classpaths
- val systemProperties = System.getProperties.iterator.toSeq
+ val systemProperties = Utils.getSystemProperties.toSeq
val otherProperties = systemProperties.filter { case (k, _) =>
k != "java.class.path" && !k.startsWith("spark.")
}.sorted
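
[Editor's note] A trivial, hypothetical sketch of why these call sites switched to named arguments (separate from the numUsableCores plumbing itself): adjacent boolean parameters such as (isDriver, isLocal) are easy to transpose positionally, and naming them also lets new trailing parameters pick up their defaults.

    object NamedArgsSketch {
      def createEnvDescription(
          executorId: String,
          isDriver: Boolean,
          isLocal: Boolean,
          numUsableCores: Int = 0): String =
        s"$executorId driver=$isDriver local=$isLocal cores=$numUsableCores"

      def main(args: Array[String]): Unit = {
        println(createEnvDescription("<driver>", isDriver = true, isLocal = false))
        println(createEnvDescription("exec-1", isDriver = false, isLocal = false, numUsableCores = 4))
      }
    }
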
diff --git a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
index 376e69cd997d5..40237596570de 100644
--- a/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
+++ b/core/src/main/scala/org/apache/spark/SparkHadoopWriter.scala
@@ -26,6 +26,7 @@ import org.apache.hadoop.mapred._
import org.apache.hadoop.fs.FileSystem
import org.apache.hadoop.fs.Path
+import org.apache.spark.mapred.SparkHadoopMapRedUtil
import org.apache.spark.rdd.HadoopRDD
/**
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala b/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
deleted file mode 100644
index a954fcc0c31fa..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslClient.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.RealmChoiceCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslClient
-import javax.security.sasl.SaslException
-
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-
-/**
- * Implements SASL Client logic for Spark
- */
-private[spark] class SparkSaslClient(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Used to respond to server's counterpart, SaslServer with SASL tokens
- * represented as byte arrays.
- *
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- private var saslClient: SaslClient = Sasl.createSaslClient(Array[String](SparkSaslServer.DIGEST),
- null, null, SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslClientCallbackHandler(securityMgr))
-
- /**
- * Used to initiate SASL handshake with server.
- * @return response to challenge if needed
- */
- def firstToken(): Array[Byte] = {
- synchronized {
- val saslToken: Array[Byte] =
- if (saslClient != null && saslClient.hasInitialResponse()) {
- logDebug("has initial response")
- saslClient.evaluateChallenge(new Array[Byte](0))
- } else {
- new Array[Byte](0)
- }
- saslToken
- }
- }
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslClient != null) saslClient.isComplete() else false
- }
- }
-
- /**
- * Respond to server's SASL token.
- * @param saslTokenMessage contains server's SASL token
- * @return client's response SASL token
- */
- def saslResponse(saslTokenMessage: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslClient != null) saslClient.evaluateChallenge(saslTokenMessage) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslClient might be using.
- */
- def dispose() {
- synchronized {
- if (saslClient != null) {
- try {
- saslClient.dispose()
- } catch {
- case e: SaslException => // ignored
- } finally {
- saslClient = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * that works with share secrets.
- */
- private class SparkSaslClientCallbackHandler(securityMgr: SecurityManager) extends
- CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
- private val secretKey = securityMgr.getSecretKey()
- private val userPassword: Array[Char] = SparkSaslServer.encodePassword(
- if (secretKey != null) secretKey.getBytes(UTF_8) else "".getBytes(UTF_8))
-
- /**
- * Implementation used to respond to SASL request from the server.
- *
- * @param callbacks objects that indicate what credential information the
- * server's SaslServer requires from the client.
- */
- override def handle(callbacks: Array[Callback]) {
- logDebug("in the sasl client callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL client callback: setting username: " + userName)
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL client callback: setting userPassword")
- pc.setPassword(userPassword)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL client callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case cb: RealmChoiceCallback => {}
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL client callback")
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala b/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
deleted file mode 100644
index 7c2afb364661f..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkSaslServer.scala
+++ /dev/null
@@ -1,176 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import javax.security.auth.callback.Callback
-import javax.security.auth.callback.CallbackHandler
-import javax.security.auth.callback.NameCallback
-import javax.security.auth.callback.PasswordCallback
-import javax.security.auth.callback.UnsupportedCallbackException
-import javax.security.sasl.AuthorizeCallback
-import javax.security.sasl.RealmCallback
-import javax.security.sasl.Sasl
-import javax.security.sasl.SaslException
-import javax.security.sasl.SaslServer
-import scala.collection.JavaConversions.mapAsJavaMap
-
-import com.google.common.base.Charsets.UTF_8
-import org.apache.commons.net.util.Base64
-
-/**
- * Encapsulates SASL server logic
- */
-private[spark] class SparkSaslServer(securityMgr: SecurityManager) extends Logging {
-
- /**
- * Actual SASL work done by this object from javax.security.sasl.
- */
- private var saslServer: SaslServer = Sasl.createSaslServer(SparkSaslServer.DIGEST, null,
- SparkSaslServer.SASL_DEFAULT_REALM, SparkSaslServer.SASL_PROPS,
- new SparkSaslDigestCallbackHandler(securityMgr))
-
- /**
- * Determines whether the authentication exchange has completed.
- * @return true is complete, otherwise false
- */
- def isComplete(): Boolean = {
- synchronized {
- if (saslServer != null) saslServer.isComplete() else false
- }
- }
-
- /**
- * Used to respond to server SASL tokens.
- * @param token Server's SASL token
- * @return response to send back to the server.
- */
- def response(token: Array[Byte]): Array[Byte] = {
- synchronized {
- if (saslServer != null) saslServer.evaluateResponse(token) else new Array[Byte](0)
- }
- }
-
- /**
- * Disposes of any system resources or security-sensitive information the
- * SaslServer might be using.
- */
- def dispose() {
- synchronized {
- if (saslServer != null) {
- try {
- saslServer.dispose()
- } catch {
- case e: SaslException => // ignore
- } finally {
- saslServer = null
- }
- }
- }
- }
-
- /**
- * Implementation of javax.security.auth.callback.CallbackHandler
- * for SASL DIGEST-MD5 mechanism
- */
- private class SparkSaslDigestCallbackHandler(securityMgr: SecurityManager)
- extends CallbackHandler {
-
- private val userName: String =
- SparkSaslServer.encodeIdentifier(securityMgr.getSaslUser().getBytes(UTF_8))
-
- override def handle(callbacks: Array[Callback]) {
- logDebug("In the sasl server callback handler")
- callbacks foreach {
- case nc: NameCallback => {
- logDebug("handle: SASL server callback: setting username")
- nc.setName(userName)
- }
- case pc: PasswordCallback => {
- logDebug("handle: SASL server callback: setting userPassword")
- val password: Array[Char] =
- SparkSaslServer.encodePassword(securityMgr.getSecretKey().getBytes(UTF_8))
- pc.setPassword(password)
- }
- case rc: RealmCallback => {
- logDebug("handle: SASL server callback: setting realm: " + rc.getDefaultText())
- rc.setText(rc.getDefaultText())
- }
- case ac: AuthorizeCallback => {
- val authid = ac.getAuthenticationID()
- val authzid = ac.getAuthorizationID()
- if (authid.equals(authzid)) {
- logDebug("set auth to true")
- ac.setAuthorized(true)
- } else {
- logDebug("set auth to false")
- ac.setAuthorized(false)
- }
- if (ac.isAuthorized()) {
- logDebug("sasl server is authorized")
- ac.setAuthorizedID(authzid)
- }
- }
- case cb: Callback => throw
- new UnsupportedCallbackException(cb, "handle: Unrecognized SASL DIGEST-MD5 Callback")
- }
- }
- }
-}
-
-private[spark] object SparkSaslServer {
-
- /**
- * This is passed as the server name when creating the sasl client/server.
- * This could be changed to be configurable in the future.
- */
- val SASL_DEFAULT_REALM = "default"
-
- /**
- * The authentication mechanism used here is DIGEST-MD5. This could be changed to be
- * configurable in the future.
- */
- val DIGEST = "DIGEST-MD5"
-
- /**
- * The quality of protection is just "auth". This means that we are doing
- * authentication only, we are not supporting integrity or privacy protection of the
- * communication channel after authentication. This could be changed to be configurable
- * in the future.
- */
- val SASL_PROPS = Map(Sasl.QOP -> "auth", Sasl.SERVER_AUTH ->"true")
-
- /**
- * Encode a byte[] identifier as a Base64-encoded string.
- *
- * @param identifier identifier to encode
- * @return Base64-encoded string
- */
- def encodeIdentifier(identifier: Array[Byte]): String = {
- new String(Base64.encodeBase64(identifier), UTF_8)
- }
-
- /**
- * Encode a password as a base64-encoded char[] array.
- * @param password as a byte array.
- * @return password as a char array.
- */
- def encodePassword(password: Array[Byte]): Array[Char] = {
- new String(Base64.encodeBase64(password), UTF_8).toCharArray()
- }
-}
-
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala b/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala
deleted file mode 100644
index 1982499c5e1d3..0000000000000
--- a/core/src/main/scala/org/apache/spark/SparkStatusAPI.scala
+++ /dev/null
@@ -1,142 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements. See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License. You may obtain a copy of the License at
- *
- * http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-package org.apache.spark
-
-import scala.collection.Map
-import scala.collection.JavaConversions._
-
-import org.apache.spark.annotation.DeveloperApi
-import org.apache.spark.rdd.RDD
-import org.apache.spark.scheduler.{SchedulingMode, Schedulable}
-import org.apache.spark.storage.{StorageStatus, StorageUtils, RDDInfo}
-
-/**
- * Trait that implements Spark's status APIs. This trait is designed to be mixed into
- * SparkContext; it allows the status API code to live in its own file.
- */
-private[spark] trait SparkStatusAPI { this: SparkContext =>
-
- /**
- * Return a map from the slave to the max memory available for caching and the remaining
- * memory available for caching.
- */
- def getExecutorMemoryStatus: Map[String, (Long, Long)] = {
- env.blockManager.master.getMemoryStatus.map { case(blockManagerId, mem) =>
- (blockManagerId.host + ":" + blockManagerId.port, mem)
- }
- }
-
- /**
- * :: DeveloperApi ::
- * Return information about what RDDs are cached, if they are in mem or on disk, how much space
- * they take, etc.
- */
- @DeveloperApi
- def getRDDStorageInfo: Array[RDDInfo] = {
- val rddInfos = persistentRdds.values.map(RDDInfo.fromRdd).toArray
- StorageUtils.updateRddInfo(rddInfos, getExecutorStorageStatus)
- rddInfos.filter(_.isCached)
- }
-
- /**
- * Returns an immutable map of RDDs that have marked themselves as persistent via cache() call.
- * Note that this does not necessarily mean the caching or computation was successful.
- */
- def getPersistentRDDs: Map[Int, RDD[_]] = persistentRdds.toMap
-
- /**
- * :: DeveloperApi ::
- * Return information about blocks stored in all of the slaves
- */
- @DeveloperApi
- def getExecutorStorageStatus: Array[StorageStatus] = {
- env.blockManager.master.getStorageStatus
- }
-
- /**
- * :: DeveloperApi ::
- * Return pools for fair scheduler
- */
- @DeveloperApi
- def getAllPools: Seq[Schedulable] = {
- // TODO(xiajunluan): We should take nested pools into account
- taskScheduler.rootPool.schedulableQueue.toSeq
- }
-
- /**
- * :: DeveloperApi ::
- * Return the pool associated with the given name, if one exists
- */
- @DeveloperApi
- def getPoolForName(pool: String): Option[Schedulable] = {
- Option(taskScheduler.rootPool.schedulableNameToSchedulable.get(pool))
- }
-
- /**
- * Return current scheduling mode
- */
- def getSchedulingMode: SchedulingMode.SchedulingMode = {
- taskScheduler.schedulingMode
- }
-
-
- /**
- * Return a list of all known jobs in a particular job group. The returned list may contain
- * running, failed, and completed jobs, and may vary across invocations of this method. This
- * method does not guarantee the order of the elements in its result.
- */
- def getJobIdsForGroup(jobGroup: String): Array[Int] = {
- jobProgressListener.synchronized {
- val jobData = jobProgressListener.jobIdToData.valuesIterator
- jobData.filter(_.jobGroup.exists(_ == jobGroup)).map(_.jobId).toArray
- }
- }
-
- /**
- * Returns job information, or `None` if the job info could not be found or was garbage collected.
- */
- def getJobInfo(jobId: Int): Option[SparkJobInfo] = {
- jobProgressListener.synchronized {
- jobProgressListener.jobIdToData.get(jobId).map { data =>
- new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status)
- }
- }
- }
-
- /**
- * Returns stage information, or `None` if the stage info could not be found or was
- * garbage collected.
- */
- def getStageInfo(stageId: Int): Option[SparkStageInfo] = {
- jobProgressListener.synchronized {
- for (
- info <- jobProgressListener.stageIdToInfo.get(stageId);
- data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId))
- ) yield {
- new SparkStageInfoImpl(
- stageId,
- info.attemptId,
- info.name,
- info.numTasks,
- data.numActiveTasks,
- data.numCompleteTasks,
- data.numFailedTasks)
- }
- }
- }
-}
diff --git a/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
new file mode 100644
index 0000000000000..edbdda8a0bcb6
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/SparkStatusTracker.scala
@@ -0,0 +1,108 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `None` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class SparkStatusTracker private[spark] (sc: SparkContext) {
+
+ private val jobProgressListener = sc.jobProgressListener
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = {
+ jobProgressListener.synchronized {
+ val jobData = jobProgressListener.jobIdToData.valuesIterator
+ jobData.filter(_.jobGroup.orNull == jobGroup).map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.activeStages.values.map(_.stageId).toArray
+ }
+ }
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.activeJobs.values.map(_.jobId).toArray
+ }
+ }
+
+ /**
+ * Returns job information, or `None` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): Option[SparkJobInfo] = {
+ jobProgressListener.synchronized {
+ jobProgressListener.jobIdToData.get(jobId).map { data =>
+ new SparkJobInfoImpl(jobId, data.stageIds.toArray, data.status)
+ }
+ }
+ }
+
+ /**
+ * Returns stage information, or `None` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): Option[SparkStageInfo] = {
+ jobProgressListener.synchronized {
+ for (
+ info <- jobProgressListener.stageIdToInfo.get(stageId);
+ data <- jobProgressListener.stageIdToData.get((stageId, info.attemptId))
+ ) yield {
+ new SparkStageInfoImpl(
+ stageId,
+ info.attemptId,
+ info.submissionTime.getOrElse(0),
+ info.name,
+ info.numTasks,
+ data.numActiveTasks,
+ data.numCompleteTasks,
+ data.numFailedTasks)
+ }
+ }
+ }
+}
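For orientation, here is a minimal usage sketch of the new tracker (using the `statusTracker` accessor on `SparkContext` that this patch delegates to; the app name, master, and job are placeholders). It polls job progress from the driver while a job runs asynchronously; because the semantics are weakly consistent, missing job info is treated as normal rather than as an error.

```scala
import scala.concurrent.Future
import scala.concurrent.ExecutionContext.Implicits.global

import org.apache.spark.{SparkConf, SparkContext}

object StatusTrackerDemo {
  def main(args: Array[String]): Unit = {
    val sc = new SparkContext(new SparkConf().setAppName("status-demo").setMaster("local[2]"))
    // Run the job asynchronously so the driver thread is free to poll for progress.
    val result = Future { sc.parallelize(1 to 100000, 20).map(_ * 2).count() }
    while (!result.isCompleted) {
      for (jobId <- sc.statusTracker.getActiveJobIds();
           info <- sc.statusTracker.getJobInfo(jobId)) {
        // Weak consistency: job info may lag behind or be missing entirely.
        println(s"job $jobId: ${info.status()} (stages ${info.stageIds().mkString(", ")})")
      }
      Thread.sleep(100)
    }
    sc.stop()
  }
}
```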
diff --git a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
index 90b47c847fbca..e5c7c8d0db578 100644
--- a/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
+++ b/core/src/main/scala/org/apache/spark/StatusAPIImpl.scala
@@ -26,6 +26,7 @@ private class SparkJobInfoImpl (
private class SparkStageInfoImpl(
val stageId: Int,
val currentAttemptId: Int,
+ val submissionTime: Long,
val name: String,
val numTasks: Int,
val numActiveTasks: Int,
diff --git a/core/src/main/scala/org/apache/spark/TaskEndReason.scala b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
index f45b463fb6f62..af5fd8e0ac00c 100644
--- a/core/src/main/scala/org/apache/spark/TaskEndReason.scala
+++ b/core/src/main/scala/org/apache/spark/TaskEndReason.scala
@@ -83,15 +83,48 @@ case class FetchFailed(
* :: DeveloperApi ::
* Task failed due to a runtime exception. This is the most common failure case and also captures
* user program exceptions.
+ *
+ * `stackTrace` contains the stack trace of the exception itself. It is kept only for backward
+ * compatibility. Prefer `this(e: Throwable, metrics: Option[TaskMetrics])` when creating an
+ * `ExceptionFailure`, since that constructor handles the backward compatibility properly.
+ *
+ * `fullStackTrace` is a better representation of the stack trace because it contains the whole
+ * stack trace, including the exception and its causes.
*/
@DeveloperApi
case class ExceptionFailure(
className: String,
description: String,
stackTrace: Array[StackTraceElement],
+ fullStackTrace: String,
metrics: Option[TaskMetrics])
extends TaskFailedReason {
- override def toErrorString: String = Utils.exceptionString(className, description, stackTrace)
+
+ private[spark] def this(e: Throwable, metrics: Option[TaskMetrics]) {
+ this(e.getClass.getName, e.getMessage, e.getStackTrace, Utils.exceptionString(e), metrics)
+ }
+
+ override def toErrorString: String =
+ if (fullStackTrace == null) {
+ // fullStackTrace is added in 1.2.0
+ // If fullStackTrace is null, use the old error string for backward compatibility
+ exceptionString(className, description, stackTrace)
+ } else {
+ fullStackTrace
+ }
+
+ /**
+ * Return a nice string representation of the exception, including the stack trace.
+ * Note: It does not include the exception's causes, and is only used for backward compatibility.
+ */
+ private def exceptionString(
+ className: String,
+ description: String,
+ stackTrace: Array[StackTraceElement]): String = {
+ val desc = if (description == null) "" else description
+ val st = if (stackTrace == null) "" else stackTrace.map(" " + _).mkString("\n")
+ s"$className: $desc\n$st"
+ }
}
/**
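A hedged sketch of what the new `fullStackTrace` field buys (the `Throwable`-based constructor above is `private[spark]`, so this constructs the case class directly; values are illustrative):

```scala
import java.io.{PrintWriter, StringWriter}

import org.apache.spark.ExceptionFailure

val e = new RuntimeException("task failed", new IllegalStateException("root cause"))

// Pre-1.2 style: no fullStackTrace, so toErrorString falls back to the legacy
// "ClassName: message" + frames format, which omits the causes.
val legacy = ExceptionFailure(e.getClass.getName, e.getMessage, e.getStackTrace, null, None)
println(legacy.toErrorString)

// 1.2+ style: carry the fully rendered trace, "Caused by:" sections included.
val writer = new StringWriter()
e.printStackTrace(new PrintWriter(writer, true))
val full = ExceptionFailure(e.getClass.getName, e.getMessage, e.getStackTrace, writer.toString, None)
println(full.toErrorString)
```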
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
index efb8978f7ce12..fa2c1c28c970d 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaRDDLike.scala
@@ -212,8 +212,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K]): JavaPairRDD[K, JIterable[T]] = {
- implicit val ctagK: ClassTag[K] = fakeClassTag
+ def groupBy[U](f: JFunction[T, U]): JavaPairRDD[U, JIterable[T]] = {
+ // The type parameter is U instead of K in order to work around a compiler bug; see SPARK-4459
+ implicit val ctagK: ClassTag[U] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f)(fakeClassTag)))
}
@@ -222,10 +223,11 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
* Return an RDD of grouped elements. Each group consists of a key and a sequence of elements
* mapping to that key.
*/
- def groupBy[K](f: JFunction[T, K], numPartitions: Int): JavaPairRDD[K, JIterable[T]] = {
- implicit val ctagK: ClassTag[K] = fakeClassTag
+ def groupBy[U](f: JFunction[T, U], numPartitions: Int): JavaPairRDD[U, JIterable[T]] = {
+ // The type parameter is U instead of K in order to work around a compiler bug; see SPARK-4459
+ implicit val ctagK: ClassTag[U] = fakeClassTag
implicit val ctagV: ClassTag[JList[T]] = fakeClassTag
- JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[K])))
+ JavaPairRDD.fromRDD(groupByResultToJava(rdd.groupBy(f, numPartitions)(fakeClassTag[U])))
}
/**
@@ -459,8 +461,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
/**
* Creates tuples of the elements in this RDD by applying `f`.
*/
- def keyBy[K](f: JFunction[T, K]): JavaPairRDD[K, T] = {
- implicit val ctag: ClassTag[K] = fakeClassTag
+ def keyBy[U](f: JFunction[T, U]): JavaPairRDD[U, T] = {
+ // The type parameter is U instead of K in order to work around a compiler bug; see SPARK-4459
+ implicit val ctag: ClassTag[U] = fakeClassTag
JavaPairRDD.fromRDD(rdd.keyBy(f))
}
@@ -493,9 +496,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD as defined by
+ * Returns the top k (largest) elements from this RDD as defined by
* the specified Comparator[T].
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -507,9 +510,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the top K elements from this RDD using the
+ * Returns the top k (largest) elements from this RDD using the
* natural ordering for T.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def top(num: Int): JList[T] = {
@@ -518,9 +521,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD as defined by
+ * Returns the first k (smallest) elements from this RDD as defined by
* the specified Comparator[T] and maintains the order.
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param comp the comparator that defines the order
* @return an array of top elements
*/
@@ -552,9 +555,9 @@ trait JavaRDDLike[T, This <: JavaRDDLike[T, This]] extends Serializable {
}
/**
- * Returns the first K elements from this RDD using the
+ * Returns the first k (smallest) elements from this RDD using the
   * natural ordering for T while maintaining the order.
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @return an array of top elements
*/
def takeOrdered(num: Int): JList[T] = {
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
index e3aeba7e6c39d..6a6d9bf6857d3 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkContext.scala
@@ -21,11 +21,6 @@ import java.io.Closeable
import java.util
import java.util.{Map => JMap}
-import java.io.DataInputStream
-
-import org.apache.hadoop.io.{BytesWritable, LongWritable}
-import org.apache.spark.input.{PortableDataStream, FixedLengthBinaryInputFormat}
-
import scala.collection.JavaConversions
import scala.collection.JavaConversions._
import scala.language.implicitConversions
@@ -33,6 +28,7 @@ import scala.reflect.ClassTag
import com.google.common.base.Optional
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.input.PortableDataStream
import org.apache.hadoop.mapred.{InputFormat, JobConf}
import org.apache.hadoop.mapreduce.{InputFormat => NewInputFormat}
@@ -46,6 +42,9 @@ import org.apache.spark.rdd.{EmptyRDD, HadoopRDD, NewHadoopRDD, RDD}
/**
* A Java-friendly version of [[org.apache.spark.SparkContext]] that returns
* [[org.apache.spark.api.java.JavaRDD]]s and works with Java collections instead of Scala ones.
+ *
+ * Only one SparkContext may be active per JVM. You must `stop()` the active SparkContext before
+ * creating a new one. This limitation may eventually be removed; see SPARK-2243 for more details.
*/
class JavaSparkContext(val sc: SparkContext)
extends JavaSparkContextVarargsWorkaround with Closeable {
@@ -109,6 +108,8 @@ class JavaSparkContext(val sc: SparkContext)
private[spark] val env = sc.env
+ def statusTracker = new JavaSparkStatusTracker(sc)
+
def isLocal: java.lang.Boolean = sc.isLocal
def sparkUser: String = sc.sparkUser
@@ -138,25 +139,6 @@ class JavaSparkContext(val sc: SparkContext)
/** Default min number of partitions for Hadoop RDDs when not given by user */
def defaultMinPartitions: java.lang.Integer = sc.defaultMinPartitions
-
- /**
- * Return a list of all known jobs in a particular job group. The returned list may contain
- * running, failed, and completed jobs, and may vary across invocations of this method. This
- * method does not guarantee the order of the elements in its result.
- */
- def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.getJobIdsForGroup(jobGroup)
-
- /**
- * Returns job information, or `null` if the job info could not be found or was garbage collected.
- */
- def getJobInfo(jobId: Int): SparkJobInfo = sc.getJobInfo(jobId).orNull
-
- /**
- * Returns stage information, or `null` if the stage info could not be found or was
- * garbage collected.
- */
- def getStageInfo(stageId: Int): SparkStageInfo = sc.getStageInfo(stageId).orNull
-
/** Distribute a local Scala collection to form an RDD. */
def parallelize[T](list: java.util.List[T], numSlices: Int): JavaRDD[T] = {
implicit val ctag: ClassTag[T] = fakeClassTag
@@ -286,6 +268,8 @@ class JavaSparkContext(val sc: SparkContext)
new JavaPairRDD(sc.binaryFiles(path, minPartitions))
/**
+ * :: Experimental ::
+ *
* Read a directory of binary files from HDFS, a local file system (available on all nodes),
* or any Hadoop-supported file system URI as a byte array. Each file is read as a single
* record and returned in a key-value pair, where the key is the path of each file,
@@ -312,15 +296,19 @@ class JavaSparkContext(val sc: SparkContext)
*
   * @note Small files are preferred; very large files may cause bad performance.
*/
+ @Experimental
def binaryFiles(path: String): JavaPairRDD[String, PortableDataStream] =
new JavaPairRDD(sc.binaryFiles(path, defaultMinPartitions))
/**
+ * :: Experimental ::
+ *
* Load data from a flat binary file, assuming the length of each record is constant.
*
* @param path Directory to the input data files
* @return An RDD of data with values, represented as byte arrays
*/
+ @Experimental
def binaryRecords(path: String, recordLength: Int): JavaRDD[Array[Byte]] = {
new JavaRDD(sc.binaryRecords(path, recordLength))
}
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
new file mode 100644
index 0000000000000..3300cad9efbab
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaSparkStatusTracker.scala
@@ -0,0 +1,72 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.java
+
+import org.apache.spark.{SparkStageInfo, SparkJobInfo, SparkContext}
+
+/**
+ * Low-level status reporting APIs for monitoring job and stage progress.
+ *
+ * These APIs intentionally provide very weak consistency semantics; consumers of these APIs should
+ * be prepared to handle empty / missing information. For example, a job's stage ids may be known
+ * but the status API may not have any information about the details of those stages, so
+ * `getStageInfo` could potentially return `null` for a valid stage id.
+ *
+ * To limit memory usage, these APIs only provide information on recent jobs / stages. These APIs
+ * will provide information for the last `spark.ui.retainedStages` stages and
+ * `spark.ui.retainedJobs` jobs.
+ *
+ * NOTE: this class's constructor should be considered private and may be subject to change.
+ */
+class JavaSparkStatusTracker private[spark] (sc: SparkContext) {
+
+ /**
+ * Return a list of all known jobs in a particular job group. If `jobGroup` is `null`, then
+ * returns all known jobs that are not associated with a job group.
+ *
+ * The returned list may contain running, failed, and completed jobs, and may vary across
+ * invocations of this method. This method does not guarantee the order of the elements in
+ * its result.
+ */
+ def getJobIdsForGroup(jobGroup: String): Array[Int] = sc.statusTracker.getJobIdsForGroup(jobGroup)
+
+ /**
+ * Returns an array containing the ids of all active stages.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveStageIds(): Array[Int] = sc.statusTracker.getActiveStageIds()
+
+ /**
+ * Returns an array containing the ids of all active jobs.
+ *
+ * This method does not guarantee the order of the elements in its result.
+ */
+ def getActiveJobIds(): Array[Int] = sc.statusTracker.getActiveJobIds()
+
+ /**
+ * Returns job information, or `null` if the job info could not be found or was garbage collected.
+ */
+ def getJobInfo(jobId: Int): SparkJobInfo = sc.statusTracker.getJobInfo(jobId).orNull
+
+ /**
+ * Returns stage information, or `null` if the stage info could not be found or was
+ * garbage collected.
+ */
+ def getStageInfo(stageId: Int): SparkStageInfo = sc.statusTracker.getStageInfo(stageId).orNull
+}
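A short sketch (in Scala, for consistency with the rest of this patch) of the Java-facing convention: where the Scala tracker returns `Option`, this wrapper signals missing data with `null`. The job id below is just a placeholder that is unlikely to exist.

```scala
import org.apache.spark.SparkConf
import org.apache.spark.api.java.JavaSparkContext

val jsc = new JavaSparkContext(new SparkConf().setAppName("java-status-demo").setMaster("local[2]"))
val tracker = jsc.statusTracker
// The Java API returns null rather than None when a job is unknown or already evicted.
val info = tracker.getJobInfo(Int.MaxValue)  // placeholder id, unlikely to exist
if (info == null) {
  println("job info not found or already garbage collected")
}
jsc.stop()
```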
diff --git a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
index b52d0a5028e84..86e94931300f8 100644
--- a/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
+++ b/core/src/main/scala/org/apache/spark/api/java/JavaUtils.scala
@@ -19,7 +19,8 @@ package org.apache.spark.api.java
import com.google.common.base.Optional
-import scala.collection.convert.Wrappers.MapWrapper
+import java.{util => ju}
+import scala.collection.mutable
private[spark] object JavaUtils {
def optionToOptional[T](option: Option[T]): Optional[T] =
@@ -32,7 +33,64 @@ private[spark] object JavaUtils {
def mapAsSerializableJavaMap[A, B](underlying: collection.Map[A, B]) =
new SerializableMapWrapper(underlying)
+ // Implementation is copied from scala.collection.convert.Wrappers.MapWrapper,
+ // but implements java.io.Serializable. It can't just be subclassed to make it
+ // Serializable since the MapWrapper class has no no-arg constructor. This class
+ // doesn't need a no-arg constructor though.
class SerializableMapWrapper[A, B](underlying: collection.Map[A, B])
- extends MapWrapper(underlying) with java.io.Serializable
+ extends ju.AbstractMap[A, B] with java.io.Serializable { self =>
+ override def size = underlying.size
+
+ override def get(key: AnyRef): B = try {
+ underlying get key.asInstanceOf[A] match {
+ case None => null.asInstanceOf[B]
+ case Some(v) => v
+ }
+ } catch {
+ case ex: ClassCastException => null.asInstanceOf[B]
+ }
+
+ override def entrySet: ju.Set[ju.Map.Entry[A, B]] = new ju.AbstractSet[ju.Map.Entry[A, B]] {
+ def size = self.size
+
+ def iterator = new ju.Iterator[ju.Map.Entry[A, B]] {
+ val ui = underlying.iterator
+ var prev : Option[A] = None
+
+ def hasNext = ui.hasNext
+
+ def next() = {
+ val (k, v) = ui.next
+ prev = Some(k)
+ new ju.Map.Entry[A, B] {
+ import scala.util.hashing.byteswap32
+ def getKey = k
+ def getValue = v
+ def setValue(v1 : B) = self.put(k, v1)
+ override def hashCode = byteswap32(k.hashCode) + (byteswap32(v.hashCode) << 16)
+ override def equals(other: Any) = other match {
+ case e: ju.Map.Entry[_, _] => k == e.getKey && v == e.getValue
+ case _ => false
+ }
+ }
+ }
+
+ def remove() {
+ prev match {
+ case Some(k) =>
+ underlying match {
+ case mm: mutable.Map[a, _] =>
+ mm remove k
+ prev = None
+ case _ =>
+ throw new UnsupportedOperationException("remove")
+ }
+ case _ =>
+ throw new IllegalStateException("next must be called at least once before remove")
+ }
+ }
+ }
+ }
+ }
}
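The point of the wrapper is to survive Java serialization when a Scala map is handed to Java code as a `java.util.Map`. A generic round-trip helper like the hypothetical sketch below is the kind of check it needs to pass:

```scala
import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}

// Hypothetical helper: serialize a value to bytes and read it back, verifying that it
// (and everything it references) implements java.io.Serializable.
def roundTrip[T](value: T): T = {
  val bytes = new ByteArrayOutputStream()
  val out = new ObjectOutputStream(bytes)
  out.writeObject(value)
  out.close()
  val in = new ObjectInputStream(new ByteArrayInputStream(bytes.toByteArray))
  in.readObject().asInstanceOf[T]
}

// A plain Scala Map is Serializable; the wrapper above exists so that the java.util.Map
// view handed to Java callers can make the same trip.
val copy = roundTrip(Map("a" -> 1, "b" -> 2))
println(copy("b"))  // 2
```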
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
index 49dc95f349eac..5ba66178e2b78 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonHadoopUtil.scala
@@ -61,8 +61,7 @@ private[python] object Converter extends Logging {
* Other objects are passed through without conversion.
*/
private[python] class WritableToJavaConverter(
- conf: Broadcast[SerializableWritable[Configuration]],
- batchSize: Int) extends Converter[Any, Any] {
+ conf: Broadcast[SerializableWritable[Configuration]]) extends Converter[Any, Any] {
/**
* Converts a [[org.apache.hadoop.io.Writable]] to the underlying primitive, String or
@@ -94,8 +93,7 @@ private[python] class WritableToJavaConverter(
map.put(convertWritable(k), convertWritable(v))
}
map
- case w: Writable =>
- if (batchSize > 1) WritableUtils.clone(w, conf.value.value) else w
+ case w: Writable => WritableUtils.clone(w, conf.value.value)
case other => other
}
}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
index 61b125ef7c6c1..e0bc00e1eb249 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonRDD.scala
@@ -19,15 +19,15 @@ package org.apache.spark.api.python
import java.io._
import java.net._
-import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, Collections}
+import java.util.{List => JList, ArrayList => JArrayList, Map => JMap, UUID, Collections}
+
+import org.apache.spark.input.PortableDataStream
import scala.collection.JavaConversions._
-import scala.collection.JavaConverters._
import scala.collection.mutable
import scala.language.existentials
import com.google.common.base.Charsets.UTF_8
-import net.razorvine.pickle.{Pickler, Unpickler}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.io.compress.CompressionCodec
@@ -47,7 +47,7 @@ private[spark] class PythonRDD(
pythonIncludes: JList[String],
preservePartitoning: Boolean,
pythonExec: String,
- broadcastVars: JList[Broadcast[Array[Byte]]],
+ broadcastVars: JList[Broadcast[PythonBroadcast]],
accumulator: Accumulator[JList[Array[Byte]]])
extends RDD[Array[Byte]](parent) {
@@ -230,8 +230,7 @@ private[spark] class PythonRDD(
if (!oldBids.contains(broadcast.id)) {
// send new broadcast
dataOut.writeLong(broadcast.id)
- dataOut.writeInt(broadcast.value.length)
- dataOut.write(broadcast.value)
+ PythonRDD.writeUTF(broadcast.value.path, dataOut)
oldBids.add(broadcast.id)
}
}
@@ -368,16 +367,8 @@ private[spark] object PythonRDD extends Logging {
}
}
- def readBroadcastFromFile(sc: JavaSparkContext, filename: String): Broadcast[Array[Byte]] = {
- val file = new DataInputStream(new FileInputStream(filename))
- try {
- val length = file.readInt()
- val obj = new Array[Byte](length)
- file.readFully(obj)
- sc.broadcast(obj)
- } finally {
- file.close()
- }
+ def readBroadcastFromFile(sc: JavaSparkContext, path: String): Broadcast[PythonBroadcast] = {
+ sc.broadcast(new PythonBroadcast(path))
}
def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream) {
@@ -397,22 +388,33 @@ private[spark] object PythonRDD extends Logging {
newIter.asInstanceOf[Iterator[String]].foreach { str =>
writeUTF(str, dataOut)
}
- case pair: Tuple2[_, _] =>
- pair._1 match {
- case bytePair: Array[Byte] =>
- newIter.asInstanceOf[Iterator[Tuple2[Array[Byte], Array[Byte]]]].foreach { pair =>
- dataOut.writeInt(pair._1.length)
- dataOut.write(pair._1)
- dataOut.writeInt(pair._2.length)
- dataOut.write(pair._2)
- }
- case stringPair: String =>
- newIter.asInstanceOf[Iterator[Tuple2[String, String]]].foreach { pair =>
- writeUTF(pair._1, dataOut)
- writeUTF(pair._2, dataOut)
- }
- case other =>
- throw new SparkException("Unexpected Tuple2 element type " + pair._1.getClass)
+ case stream: PortableDataStream =>
+ newIter.asInstanceOf[Iterator[PortableDataStream]].foreach { stream =>
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, stream: PortableDataStream) =>
+ newIter.asInstanceOf[Iterator[(String, PortableDataStream)]].foreach {
+ case (key, stream) =>
+ writeUTF(key, dataOut)
+ val bytes = stream.toArray()
+ dataOut.writeInt(bytes.length)
+ dataOut.write(bytes)
+ }
+ case (key: String, value: String) =>
+ newIter.asInstanceOf[Iterator[(String, String)]].foreach {
+ case (key, value) =>
+ writeUTF(key, dataOut)
+ writeUTF(value, dataOut)
+ }
+ case (key: Array[Byte], value: Array[Byte]) =>
+ newIter.asInstanceOf[Iterator[(Array[Byte], Array[Byte])]].foreach {
+ case (key, value) =>
+ dataOut.writeInt(key.length)
+ dataOut.write(key)
+ dataOut.writeInt(value.length)
+ dataOut.write(value)
}
case other =>
throw new SparkException("Unexpected element type " + first.getClass)
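For reference, byte-array and `PortableDataStream` records are framed as a 4-byte big-endian length followed by the raw bytes (string records go through `writeUTF`). A minimal sketch of that framing for a key/value pair:

```scala
import java.io.{ByteArrayOutputStream, DataOutputStream}

// Each component of an (Array[Byte], Array[Byte]) record is written as a 4-byte
// big-endian length followed by the raw bytes.
def writePair(key: Array[Byte], value: Array[Byte], out: DataOutputStream): Unit = {
  out.writeInt(key.length)
  out.write(key)
  out.writeInt(value.length)
  out.write(value)
}

val buffer = new ByteArrayOutputStream()
writePair("k".getBytes("UTF-8"), "v".getBytes("UTF-8"), new DataOutputStream(buffer))
println(buffer.size())  // 4 + 1 + 4 + 1 = 10 bytes
```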
@@ -442,7 +444,7 @@ private[spark] object PythonRDD extends Logging {
val rdd = sc.sc.sequenceFile[K, V](path, kc, vc, minSplits)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(sc.hadoopConfiguration()))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -468,7 +470,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -494,7 +496,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -537,7 +539,7 @@ private[spark] object PythonRDD extends Logging {
Some(path), inputFormatClass, keyClass, valueClass, mergedConf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(mergedConf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -563,7 +565,7 @@ private[spark] object PythonRDD extends Logging {
None, inputFormatClass, keyClass, valueClass, conf)
val confBroadcasted = sc.sc.broadcast(new SerializableWritable(conf))
val converted = convertRDD(rdd, keyConverterClass, valueConverterClass,
- new WritableToJavaConverter(confBroadcasted, batchSize))
+ new WritableToJavaConverter(confBroadcasted))
JavaRDD.fromRDD(SerDeUtil.pairRDDToPython(converted, batchSize))
}
@@ -746,104 +748,6 @@ private[spark] object PythonRDD extends Logging {
converted.saveAsHadoopDataset(new JobConf(conf))
}
}
-
-
- /**
- * Convert an RDD of serialized Python dictionaries to Scala Maps (no recursive conversions).
- */
- @deprecated("PySpark does not use it anymore", "1.1")
- def pythonToJavaMap(pyRDD: JavaRDD[Array[Byte]]): JavaRDD[Map[String, _]] = {
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- SerDeUtil.initialize()
- iter.flatMap { row =>
- unpickle.loads(row) match {
- // in case of objects are pickled in batch mode
- case objs: JArrayList[JMap[String, _] @unchecked] => objs.map(_.toMap)
- // not in batch mode
- case obj: JMap[String @unchecked, _] => Seq(obj.toMap)
- }
- }
- }
- }
-
- /**
- * Convert an RDD of serialized Python tuple to Array (no recursive conversions).
- * It is only used by pyspark.sql.
- */
- def pythonToJavaArray(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Array[_]] = {
-
- def toArray(obj: Any): Array[_] = {
- obj match {
- case objs: JArrayList[_] =>
- objs.toArray
- case obj if obj.getClass.isArray =>
- obj.asInstanceOf[Array[_]].toArray
- }
- }
-
- pyRDD.rdd.mapPartitions { iter =>
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].map(toArray)
- } else {
- Seq(toArray(obj))
- }
- }
- }.toJavaRDD()
- }
-
- private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
- private val pickle = new Pickler()
- private var batch = 1
- private val buffer = new mutable.ArrayBuffer[Any]
-
- override def hasNext(): Boolean = iter.hasNext
-
- override def next(): Array[Byte] = {
- while (iter.hasNext && buffer.length < batch) {
- buffer += iter.next()
- }
- val bytes = pickle.dumps(buffer.toArray)
- val size = bytes.length
- // let 1M < size < 10M
- if (size < 1024 * 1024) {
- batch *= 2
- } else if (size > 1024 * 1024 * 10 && batch > 1) {
- batch /= 2
- }
- buffer.clear()
- bytes
- }
- }
-
- /**
- * Convert an RDD of Java objects to an RDD of serialized Python objects, that is usable by
- * PySpark.
- */
- def javaToPython(jRDD: JavaRDD[Any]): JavaRDD[Array[Byte]] = {
- jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
- }
-
- /**
- * Convert an RDD of serialized Python objects to RDD of objects, that is usable by PySpark.
- */
- def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
- pyRDD.rdd.mapPartitions { iter =>
- SerDeUtil.initialize()
- val unpickle = new Unpickler
- iter.flatMap { row =>
- val obj = unpickle.loads(row)
- if (batched) {
- obj.asInstanceOf[JArrayList[_]].asScala
- } else {
- Seq(obj)
- }
- }
- }.toJavaRDD()
- }
}
private
@@ -903,3 +807,49 @@ private class PythonAccumulatorParam(@transient serverHost: String, serverPort:
}
}
}
+
+/**
+ * A wrapper for a Python broadcast variable, whose data is written to disk by Python. After
+ * deserialization it writes the data back to disk, so that Python can read it from disk again.
+ */
+private[spark] class PythonBroadcast(@transient var path: String) extends Serializable {
+
+ /**
+   * Read the data from disk, then copy it to `out`.
+ */
+ private def writeObject(out: ObjectOutputStream): Unit = Utils.tryOrIOException {
+ val in = new FileInputStream(new File(path))
+ try {
+ Utils.copyStream(in, out)
+ } finally {
+ in.close()
+ }
+ }
+
+ /**
+   * Write the data to disk, using a randomly generated file name.
+ */
+ private def readObject(in: ObjectInputStream): Unit = Utils.tryOrIOException {
+ val dir = new File(Utils.getLocalDir(SparkEnv.get.conf))
+ val file = File.createTempFile("broadcast", "", dir)
+ path = file.getAbsolutePath
+ val out = new FileOutputStream(file)
+ try {
+ Utils.copyStream(in, out)
+ } finally {
+ out.close()
+ }
+ }
+
+ /**
+ * Delete the file once the object is GCed.
+ */
+ override def finalize() {
+ if (!path.isEmpty) {
+ val file = new File(path)
+ if (file.exists()) {
+ file.delete()
+ }
+ }
+ }
+}
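The `writeObject`/`readObject` pair is Java's custom-serialization hook: the payload never lives on the heap as a field, it is streamed from a local file on the sender and back into a temp file on the receiver. A standalone sketch of the same pattern, with a hypothetical class name and plain stream copying in place of Spark's `Utils` helpers:

```scala
import java.io._

// Hypothetical class illustrating the same serialize-by-file-content pattern as
// PythonBroadcast, without Spark's Utils.copyStream / tryOrIOException helpers.
class FileBackedPayload(@transient var path: String) extends Serializable {

  private def copy(in: InputStream, out: OutputStream): Unit = {
    val buf = new Array[Byte](8192)
    var n = in.read(buf)
    while (n != -1) {
      out.write(buf, 0, n)
      n = in.read(buf)
    }
  }

  // Serialize by streaming the backing file's bytes into the object stream.
  private def writeObject(out: ObjectOutputStream): Unit = {
    val in = new FileInputStream(path)
    try copy(in, out) finally in.close()
  }

  // Deserialize by spilling the bytes into a fresh temp file on the receiving side.
  private def readObject(in: ObjectInputStream): Unit = {
    val file = File.createTempFile("payload", "")
    path = file.getAbsolutePath
    val out = new FileOutputStream(file)
    try copy(in, out) finally out.close()
  }
}
```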
diff --git a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
index ebdc3533e0992..a4153aaa926f8 100644
--- a/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/SerDeUtil.scala
@@ -18,8 +18,13 @@
package org.apache.spark.api.python
import java.nio.ByteOrder
+import java.util.{ArrayList => JArrayList}
+
+import org.apache.spark.api.java.JavaRDD
import scala.collection.JavaConversions._
+import scala.collection.JavaConverters._
+import scala.collection.mutable
import scala.util.Failure
import scala.util.Try
@@ -89,6 +94,73 @@ private[spark] object SerDeUtil extends Logging {
}
initialize()
+
+ /**
+ * Convert an RDD of Java objects to Array (no recursive conversions).
+ * It is only used by pyspark.sql.
+ */
+ def toJavaArray(jrdd: JavaRDD[Any]): JavaRDD[Array[_]] = {
+ jrdd.rdd.map {
+ case objs: JArrayList[_] =>
+ objs.toArray
+ case obj if obj.getClass.isArray =>
+ obj.asInstanceOf[Array[_]].toArray
+ }.toJavaRDD()
+ }
+
+ /**
+ * Choose batch size based on size of objects
+ */
+ private[spark] class AutoBatchedPickler(iter: Iterator[Any]) extends Iterator[Array[Byte]] {
+ private val pickle = new Pickler()
+ private var batch = 1
+ private val buffer = new mutable.ArrayBuffer[Any]
+
+ override def hasNext: Boolean = iter.hasNext
+
+ override def next(): Array[Byte] = {
+ while (iter.hasNext && buffer.length < batch) {
+ buffer += iter.next()
+ }
+ val bytes = pickle.dumps(buffer.toArray)
+ val size = bytes.length
+ // let 1M < size < 10M
+ if (size < 1024 * 1024) {
+ batch *= 2
+ } else if (size > 1024 * 1024 * 10 && batch > 1) {
+ batch /= 2
+ }
+ buffer.clear()
+ bytes
+ }
+ }
+
+ /**
+   * Convert an RDD of Java objects to an RDD of serialized Python objects that PySpark can
+   * consume.
+ */
+ private[spark] def javaToPython(jRDD: JavaRDD[_]): JavaRDD[Array[Byte]] = {
+ jRDD.rdd.mapPartitions { iter => new AutoBatchedPickler(iter) }
+ }
+
+ /**
+   * Convert an RDD of serialized Python objects to an RDD of Java objects that PySpark can use.
+ */
+ def pythonToJava(pyRDD: JavaRDD[Array[Byte]], batched: Boolean): JavaRDD[Any] = {
+ pyRDD.rdd.mapPartitions { iter =>
+ initialize()
+ val unpickle = new Unpickler
+ iter.flatMap { row =>
+ val obj = unpickle.loads(row)
+ if (batched) {
+ obj.asInstanceOf[JArrayList[_]].asScala
+ } else {
+ Seq(obj)
+ }
+ }
+ }.toJavaRDD()
+ }
+
private def checkPickle(t: (Any, Any)): (Boolean, Boolean) = {
val pickle = new Pickler
val kt = Try {
@@ -128,17 +200,18 @@ private[spark] object SerDeUtil extends Logging {
*/
def pairRDDToPython(rdd: RDD[(Any, Any)], batchSize: Int): RDD[Array[Byte]] = {
val (keyFailed, valueFailed) = checkPickle(rdd.first())
+
rdd.mapPartitions { iter =>
- val pickle = new Pickler
val cleaned = iter.map { case (k, v) =>
val key = if (keyFailed) k.toString else k
val value = if (valueFailed) v.toString else v
Array[Any](key, value)
}
- if (batchSize > 1) {
- cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
+ if (batchSize == 0) {
+ new AutoBatchedPickler(cleaned)
} else {
- cleaned.map(pickle.dumps(_))
+ val pickle = new Pickler
+ cleaned.grouped(batchSize).map(batched => pickle.dumps(seqAsJavaList(batched)))
}
}
}
@@ -146,36 +219,22 @@ private[spark] object SerDeUtil extends Logging {
/**
* Convert an RDD of serialized Python tuple (K, V) to RDD[(K, V)].
*/
- def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batchSerialized: Boolean): RDD[(K, V)] = {
+ def pythonToPairRDD[K, V](pyRDD: RDD[Array[Byte]], batched: Boolean): RDD[(K, V)] = {
def isPair(obj: Any): Boolean = {
- Option(obj.getClass.getComponentType).map(!_.isPrimitive).getOrElse(false) &&
+ Option(obj.getClass.getComponentType).exists(!_.isPrimitive) &&
obj.asInstanceOf[Array[_]].length == 2
}
- pyRDD.mapPartitions { iter =>
- initialize()
- val unpickle = new Unpickler
- val unpickled =
- if (batchSerialized) {
- iter.flatMap { batch =>
- unpickle.loads(batch) match {
- case objs: java.util.List[_] => collectionAsScalaIterable(objs)
- case other => throw new SparkException(
- s"Unexpected type ${other.getClass.getName} for batch serialized Python RDD")
- }
- }
- } else {
- iter.map(unpickle.loads(_))
- }
- unpickled.map {
- case obj if isPair(obj) =>
- // we only accept (K, V)
- val arr = obj.asInstanceOf[Array[_]]
- (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
- case other => throw new SparkException(
- s"RDD element of type ${other.getClass.getName} cannot be used")
- }
+
+ val rdd = pythonToJava(pyRDD, batched).rdd
+ rdd.first match {
+ case obj if isPair(obj) =>
+ // we only accept (K, V)
+ case other => throw new SparkException(
+ s"RDD element of type ${other.getClass.getName} cannot be used")
+ }
+ rdd.map { obj =>
+ val arr = obj.asInstanceOf[Array[_]]
+ (arr.head.asInstanceOf[K], arr.last.asInstanceOf[V])
}
}
-
}
-
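`AutoBatchedPickler` sizes batches adaptively: the batch doubles while a pickled batch stays under 1 MB and halves once one exceeds 10 MB, keeping serialized batches roughly in the 1-10 MB range. A standalone sketch of that heuristic with the pickler abstracted away (the `serialize` function is a stand-in):

```scala
import scala.collection.mutable.ArrayBuffer

// Grow the batch while serialized output stays under 1 MB; shrink it above 10 MB.
// `serialize` stands in for the Pickler used in the real implementation.
def autoBatch[T](items: Iterator[T])(serialize: Seq[T] => Array[Byte]): Iterator[Array[Byte]] =
  new Iterator[Array[Byte]] {
    private var batch = 1
    private val buffer = ArrayBuffer.empty[T]

    override def hasNext: Boolean = items.hasNext

    override def next(): Array[Byte] = {
      while (items.hasNext && buffer.length < batch) {
        buffer += items.next()
      }
      val bytes = serialize(buffer.toSeq)
      if (bytes.length < 1024 * 1024) {
        batch *= 2
      } else if (bytes.length > 1024 * 1024 * 10 && batch > 1) {
        batch /= 2
      }
      buffer.clear()
      bytes
    }
  }
```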
diff --git a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
index e9ca9166eb4d6..c0cbd28a845be 100644
--- a/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/WriteInputFormatTestDataGenerator.scala
@@ -176,11 +176,11 @@ object WriteInputFormatTestDataGenerator {
// Create test data for arbitrary custom writable TestWritable
val testClass = Seq(
- ("1", TestWritable("test1", 123, 54.0)),
- ("2", TestWritable("test2", 456, 8762.3)),
- ("1", TestWritable("test3", 123, 423.1)),
- ("3", TestWritable("test56", 456, 423.5)),
- ("2", TestWritable("test2", 123, 5435.2))
+ ("1", TestWritable("test1", 1, 1.0)),
+ ("2", TestWritable("test2", 2, 2.3)),
+ ("3", TestWritable("test3", 3, 3.1)),
+ ("5", TestWritable("test56", 5, 5.5)),
+ ("4", TestWritable("test4", 4, 4.2))
)
val rdd = sc.parallelize(testClass, numSlices = 2).map{ case (k, v) => (new Text(k), v) }
rdd.saveAsNewAPIHadoopFile(classPath,
diff --git a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
index 87f5cf944ed85..a5ea478f231d7 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/Broadcast.scala
@@ -39,7 +39,7 @@ import scala.reflect.ClassTag
*
* {{{
* scala> val broadcastVar = sc.broadcast(Array(1, 2, 3))
- * broadcastVar: spark.Broadcast[Array[Int]] = spark.Broadcast(b5c40191-a864-4c7d-b9bf-d87e1a4e787c)
+ * broadcastVar: org.apache.spark.broadcast.Broadcast[Array[Int]] = Broadcast(0)
*
* scala> broadcastVar.value
* res0: Array[Int] = Array(1, 2, 3)
diff --git a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
index 7dade04273b08..ea98051532a0a 100644
--- a/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
+++ b/core/src/main/scala/org/apache/spark/broadcast/HttpBroadcast.scala
@@ -151,9 +151,10 @@ private[broadcast] object HttpBroadcast extends Logging {
}
private def createServer(conf: SparkConf) {
- broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf))
+ broadcastDir = Utils.createTempDir(Utils.getLocalDir(conf), "broadcast")
val broadcastPort = conf.getInt("spark.broadcast.port", 0)
- server = new HttpServer(broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
+ server =
+ new HttpServer(conf, broadcastDir, securityManager, broadcastPort, "HTTP broadcast server")
server.start()
serverUri = server.uri
logInfo("Broadcast server started at " + serverUri)
@@ -191,10 +192,12 @@ private[broadcast] object HttpBroadcast extends Logging {
logDebug("broadcast security enabled")
val newuri = Utils.constructURIForAuthentication(new URI(url), securityManager)
uc = newuri.toURL.openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
uc.setAllowUserInteraction(false)
} else {
logDebug("broadcast not using security")
uc = new URL(url).openConnection()
+ uc.setConnectTimeout(httpReadTimeout)
}
val in = {
diff --git a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
index 4e802e02c4149..2e1e52906ceeb 100644
--- a/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/ClientArguments.scala
@@ -75,7 +75,8 @@ private[spark] class ClientArguments(args: Array[String]) {
if (!ClientArguments.isValidJarUrl(_jarUrl)) {
println(s"Jar url '${_jarUrl}' is not in valid format.")
- println(s"Must be a jar file path in URL format (e.g. hdfs://XX.jar, file://XX.jar)")
+ println(s"Must be a jar file path in URL format " +
+ "(e.g. hdfs://host:port/XX.jar, file:///XX.jar)")
printUsageAndExit(-1)
}
@@ -119,7 +120,7 @@ object ClientArguments {
def isValidJarUrl(s: String): Boolean = {
try {
val uri = new URI(s)
- uri.getScheme != null && uri.getAuthority != null && s.endsWith("jar")
+ uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
} catch {
case _: URISyntaxException => false
}
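The behavioural change is easiest to see in isolation. The snippet below mirrors the new check (a scheme plus a path ending in `.jar`) rather than the old one (an authority plus any string ending in `jar`):

```scala
import java.net.{URI, URISyntaxException}

def isValidJarUrl(s: String): Boolean =
  try {
    val uri = new URI(s)
    uri.getScheme != null && uri.getPath != null && uri.getPath.endsWith(".jar")
  } catch {
    case _: URISyntaxException => false
  }

// "file:///apps/app.jar" now passes even though the URI has no authority component,
// while "hdfs://host:9000/notajar" is rejected because its path does not end in ".jar".
println(isValidJarUrl("file:///apps/app.jar"))     // true
println(isValidJarUrl("hdfs://host:9000/notajar")) // false
```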
diff --git a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
index b9dd8557ee904..243d8edb72ed3 100644
--- a/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/DeployMessage.scala
@@ -88,10 +88,14 @@ private[deploy] object DeployMessages {
case class KillDriver(driverId: String) extends DeployMessage
+ case class ApplicationFinished(id: String)
+
// Worker internal
case object WorkDirCleanup // Sent to Worker actor periodically for cleaning up app folders
+ case object ReregisterWithMaster // used when a worker attempts to reconnect to a master
+
// AppClient to Master
case class RegisterApplication(appDescription: ApplicationDescription)
@@ -173,4 +177,5 @@ private[deploy] object DeployMessages {
// Liveness checks in various places
case object SendHeartbeat
+
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index af94b05ce3847..039c8719e2867 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -87,8 +87,8 @@ object PythonRunner {
// Strip the URI scheme from the path
formattedPath =
new URI(formattedPath).getScheme match {
- case Utils.windowsDrive(d) if windows => formattedPath
case null => formattedPath
+ case Utils.windowsDrive(d) if windows => formattedPath
case _ => new URI(formattedPath).getPath
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
index e28eaad8a5180..57f9faf5ddd1d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkHadoopUtil.scala
@@ -17,12 +17,14 @@
package org.apache.spark.deploy
+import java.lang.reflect.Method
import java.security.PrivilegedExceptionAction
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.fs.FileSystem.Statistics
import org.apache.hadoop.mapred.JobConf
+import org.apache.hadoop.mapreduce.{JobContext, TaskAttemptContext}
import org.apache.hadoop.security.Credentials
import org.apache.hadoop.security.UserGroupInformation
@@ -133,14 +135,9 @@ class SparkHadoopUtil extends Logging {
*/
private[spark] def getFSBytesReadOnThreadCallback(path: Path, conf: Configuration)
: Option[() => Long] = {
- val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
- val scheme = qualifiedPath.toUri().getScheme()
- val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
try {
- val threadStats = stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
- val statisticsDataClass =
- Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
- val getBytesReadMethod = statisticsDataClass.getDeclaredMethod("getBytesRead")
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesReadMethod = getFileSystemThreadStatisticsMethod("getBytesRead")
val f = () => threadStats.map(getBytesReadMethod.invoke(_).asInstanceOf[Long]).sum
val baselineBytesRead = f()
Some(() => f() - baselineBytesRead)
@@ -151,6 +148,53 @@ class SparkHadoopUtil extends Logging {
}
}
}
+
+ /**
+ * Returns a function that can be called to find Hadoop FileSystem bytes written. If
+ * getFSBytesWrittenOnThreadCallback is called from thread r at time t, the returned callback will
+ * return the bytes written on r since t. Reflection is required because thread-level FileSystem
+ * statistics are only available as of Hadoop 2.5 (see HADOOP-10688).
+ * Returns None if the required method can't be found.
+ */
+ private[spark] def getFSBytesWrittenOnThreadCallback(path: Path, conf: Configuration)
+ : Option[() => Long] = {
+ try {
+ val threadStats = getFileSystemThreadStatistics(path, conf)
+ val getBytesWrittenMethod = getFileSystemThreadStatisticsMethod("getBytesWritten")
+ val f = () => threadStats.map(getBytesWrittenMethod.invoke(_).asInstanceOf[Long]).sum
+ val baselineBytesWritten = f()
+ Some(() => f() - baselineBytesWritten)
+ } catch {
+ case e: NoSuchMethodException => {
+ logDebug("Couldn't find method for retrieving thread-level FileSystem output data", e)
+ None
+ }
+ }
+ }
+
+ private def getFileSystemThreadStatistics(path: Path, conf: Configuration): Seq[AnyRef] = {
+ val qualifiedPath = path.getFileSystem(conf).makeQualified(path)
+ val scheme = qualifiedPath.toUri().getScheme()
+ val stats = FileSystem.getAllStatistics().filter(_.getScheme().equals(scheme))
+ stats.map(Utils.invoke(classOf[Statistics], _, "getThreadStatistics"))
+ }
+
+ private def getFileSystemThreadStatisticsMethod(methodName: String): Method = {
+ val statisticsDataClass =
+ Class.forName("org.apache.hadoop.fs.FileSystem$Statistics$StatisticsData")
+ statisticsDataClass.getDeclaredMethod(methodName)
+ }
+
+ /**
+   * Uses reflection to get the Configuration from a JobContext/TaskAttemptContext. If we called
+   * `JobContext/TaskAttemptContext.getConfiguration` directly, different bytecode would be
+   * generated for Hadoop 1.+ and Hadoop 2.+, because JobContext/TaskAttemptContext is a class
+   * in Hadoop 1.+ but an interface in Hadoop 2.+.
+ */
+ def getConfigurationFromJobContext(context: JobContext): Configuration = {
+ val method = context.getClass.getMethod("getConfiguration")
+ method.invoke(context).asInstanceOf[Configuration]
+ }
}
object SparkHadoopUtil {
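A usage sketch of the new bytes-written callback (it is `private[spark]`, so this mirrors how Spark-internal output code is expected to call it; the path and configuration are placeholders):

```scala
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.fs.Path

import org.apache.spark.deploy.SparkHadoopUtil

val hadoopConf = new Configuration()
val outputPath = new Path("/tmp/output")  // placeholder

// The callback bakes in a baseline at creation time, so f() returns "bytes written by
// this thread since the callback was created". None means the running Hadoop version
// (< 2.5) does not expose thread-level statistics.
val bytesWritten: Option[() => Long] =
  SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(outputPath, hadoopConf)

// ... perform writes through a Hadoop FileSystem on this thread ...

bytesWritten.foreach(f => println(s"bytes written by this thread: ${f()}"))
```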
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
index b43e68e40f791..a36530c8a1e73 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmit.scala
@@ -279,6 +279,11 @@ object SparkSubmit {
sysProps.getOrElseUpdate(k, v)
}
+ // Ignore invalid spark.driver.host in cluster modes.
+ if (deployMode == CLUSTER) {
+ sysProps -= ("spark.driver.host")
+ }
+
// Resolve paths in certain spark properties
val pathConfigs = Seq(
"spark.jars",
@@ -340,7 +345,7 @@ object SparkSubmit {
e.printStackTrace(printStream)
if (childMainClass.contains("thriftserver")) {
println(s"Failed to load main class $childMainClass.")
- println("You need to build Spark with -Phive.")
+ println("You need to build Spark with -Phive and -Phive-thriftserver.")
}
System.exit(CLASS_NOT_FOUND_EXIT_STATUS)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
index f0e9ee67f6a67..1faabe91f49a8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitArguments.scala
@@ -17,6 +17,7 @@
package org.apache.spark.deploy
+import java.net.URI
import java.util.jar.JarFile
import scala.collection.mutable.{ArrayBuffer, HashMap}
@@ -120,17 +121,28 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
name = Option(name).orElse(sparkProperties.get("spark.app.name")).orNull
jars = Option(jars).orElse(sparkProperties.get("spark.jars")).orNull
deployMode = Option(deployMode).orElse(env.get("DEPLOY_MODE")).orNull
+ numExecutors = Option(numExecutors)
+ .getOrElse(sparkProperties.get("spark.executor.instances").orNull)
// Try to set main class from JAR if no --class argument is given
if (mainClass == null && !isPython && primaryResource != null) {
- try {
- val jar = new JarFile(primaryResource)
- // Note that this might still return null if no main-class is set; we catch that later
- mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class")
- } catch {
- case e: Exception =>
- SparkSubmit.printErrorAndExit("Cannot load main class from JAR: " + primaryResource)
- return
+ val uri = new URI(primaryResource)
+ val uriScheme = uri.getScheme()
+
+ uriScheme match {
+ case "file" =>
+ try {
+ val jar = new JarFile(uri.getPath)
+ // Note that this might still return null if no main-class is set; we catch that later
+ mainClass = jar.getManifest.getMainAttributes.getValue("Main-Class")
+ } catch {
+ case e: Exception =>
+ SparkSubmit.printErrorAndExit(s"Cannot load main class from JAR $primaryResource")
+ }
+ case _ =>
+ SparkSubmit.printErrorAndExit(
+ s"Cannot load main class from JAR $primaryResource with URI $uriScheme. " +
+ "Please specify a class through --class.")
}
}
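A condensed sketch of the new resolution rule: the `Main-Class` manifest attribute can only be read when the primary resource is a local `file:` jar, which is why every other scheme now requires an explicit `--class` (the helper name and null handling below are illustrative):

```scala
import java.net.URI
import java.util.jar.JarFile

// Hypothetical helper mirroring the new behaviour: only local "file" jars are inspected.
def mainClassOf(primaryResource: String): Option[String] = {
  val uri = new URI(primaryResource)
  if (uri.getScheme == "file") {
    val jar = new JarFile(uri.getPath)
    try {
      Option(jar.getManifest).flatMap(m => Option(m.getMainAttributes.getValue("Main-Class")))
    } finally {
      jar.close()
    }
  } else {
    None  // remote or schemeless resources must supply --class explicitly
  }
}
```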
@@ -212,7 +224,10 @@ private[spark] class SparkSubmitArguments(args: Seq[String], env: Map[String, St
""".stripMargin
}
- /** Fill in values by parsing user options. */
+ /**
+ * Fill in values by parsing user options.
+ * NOTE: Any changes here must be reflected in YarnClientSchedulerBackend.
+ */
private def parseOpts(opts: Seq[String]): Unit = {
val EQ_SEPARATED_OPT="""(--[^=]+)=(.+)""".r
diff --git a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
index 2b894a796c8c6..2eab9981845e8 100644
--- a/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/SparkSubmitDriverBootstrapper.scala
@@ -129,6 +129,16 @@ private[spark] object SparkSubmitDriverBootstrapper {
val process = builder.start()
+ // If we kill an app while it's running, its sub-process should be killed too.
+ Runtime.getRuntime().addShutdownHook(new Thread() {
+ override def run() = {
+ if (process != null) {
+ process.destroy()
+ process.waitFor()
+ }
+ }
+ })
+
// Redirect stdout and stderr from the child JVM
val stdoutThread = new RedirectThread(process.getInputStream, System.out, "redirect stdout")
val stderrThread = new RedirectThread(process.getErrorStream, System.err, "redirect stderr")
@@ -139,14 +149,16 @@ private[spark] object SparkSubmitDriverBootstrapper {
// subprocess there already reads directly from our stdin, so we should avoid spawning a
// thread that contends with the subprocess in reading from System.in.
val isWindows = Utils.isWindows
- val isPySparkShell = sys.env.contains("PYSPARK_SHELL")
+ val isSubprocess = sys.env.contains("IS_SUBPROCESS")
if (!isWindows) {
- val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin")
+ val stdinThread = new RedirectThread(System.in, process.getOutputStream, "redirect stdin",
+ propagateEof = true)
stdinThread.start()
- // For the PySpark shell, Spark submit itself runs as a python subprocess, and so this JVM
- // should terminate on broken pipe, which signals that the parent process has exited. In
- // Windows, the termination logic for the PySpark shell is handled in java_gateway.py
- if (isPySparkShell) {
+ // Spark submit (JVM) may run as a subprocess, and so this JVM should terminate on
+ // broken pipe, signaling that the parent process has exited. This is the case if the
+ // application is launched directly from python, as in the PySpark shell. In Windows,
+ // the termination logic is handled in java_gateway.py
+ if (isSubprocess) {
stdinThread.join()
process.destroy()
}
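The shutdown-hook pattern in isolation (the child command is a toy example; RedirectThread's new propagateEof flag is defined elsewhere in this patch):

    object ChildProcessSketch {
      def main(args: Array[String]): Unit = {
        val process = new ProcessBuilder("sleep", "60").start()
        // Mirror the bootstrapper: if this JVM is killed, take the child process down with it.
        Runtime.getRuntime.addShutdownHook(new Thread() {
          override def run(): Unit = {
            process.destroy()
            process.waitFor()
          }
        })
        process.waitFor()
      }
    }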
diff --git a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
index 98a93d1fcb2a3..4efebcaa350fe 100644
--- a/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/client/AppClient.scala
@@ -134,6 +134,7 @@ private[spark] class AppClient(
val fullId = appId + "/" + id
logInfo("Executor added: %s on %s (%s) with %d cores".format(fullId, workerId, hostPort,
cores))
+ master ! ExecutorStateChanged(appId, id, ExecutorState.RUNNING, None, None)
listener.executorAdded(fullId, workerId, hostPort, cores, memory)
case ExecutorUpdated(id, state, message, exitStatus) =>
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
index 2d1609b973607..82a54dbfb5330 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/FsHistoryProvider.scala
@@ -29,22 +29,27 @@ import org.apache.spark.scheduler._
import org.apache.spark.ui.SparkUI
import org.apache.spark.util.Utils
+/**
+ * A class that provides application history from event logs stored in the file system.
+ * This provider checks for new finished applications in the background periodically and
+ * renders the application history UI by parsing the associated event logs.
+ */
private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHistoryProvider
with Logging {
+ import FsHistoryProvider._
+
private val NOT_STARTED = ""
// Interval between each check for event log updates
private val UPDATE_INTERVAL_MS = conf.getInt("spark.history.fs.updateInterval",
conf.getInt("spark.history.updateInterval", 10)) * 1000
- private val logDir = conf.get("spark.history.fs.logDirectory", null)
- private val resolvedLogDir = Option(logDir)
- .map { d => Utils.resolveURI(d) }
- .getOrElse { throw new IllegalArgumentException("Logging directory must be specified.") }
+ private val logDir = conf.getOption("spark.history.fs.logDirectory")
+ .map { d => Utils.resolveURI(d).toString }
+ .getOrElse(DEFAULT_LOG_DIR)
- private val fs = Utils.getHadoopFileSystem(resolvedLogDir,
- SparkHadoopUtil.get.newConfiguration(conf))
+ private val fs = Utils.getHadoopFileSystem(logDir, SparkHadoopUtil.get.newConfiguration(conf))
// A timestamp of when the disk was last accessed to check for log updates
private var lastLogCheckTimeMs = -1L
@@ -87,14 +92,17 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
private def initialize() {
// Validate the log directory.
- val path = new Path(resolvedLogDir)
+ val path = new Path(logDir)
if (!fs.exists(path)) {
- throw new IllegalArgumentException(
- "Logging directory specified does not exist: %s".format(resolvedLogDir))
+ var msg = s"Log directory specified does not exist: $logDir."
+ if (logDir == DEFAULT_LOG_DIR) {
+ msg += " Did you configure the correct one through spark.fs.history.logDirectory?"
+ }
+ throw new IllegalArgumentException(msg)
}
if (!fs.getFileStatus(path).isDir) {
throw new IllegalArgumentException(
- "Logging directory specified is not a directory: %s".format(resolvedLogDir))
+ "Logging directory specified is not a directory: %s".format(logDir))
}
checkForLogs()
@@ -134,8 +142,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
}
- override def getConfig(): Map[String, String] =
- Map("Event Log Location" -> resolvedLogDir.toString)
+ override def getConfig(): Map[String, String] = Map("Event log directory" -> logDir.toString)
/**
* Builds the application list based on the current contents of the log directory.
@@ -146,7 +153,7 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
lastLogCheckTimeMs = getMonotonicTimeMs()
logDebug("Checking for logs. Time is now %d.".format(lastLogCheckTimeMs))
try {
- val logStatus = fs.listStatus(new Path(resolvedLogDir))
+ val logStatus = fs.listStatus(new Path(logDir))
val logDirs = if (logStatus != null) logStatus.filter(_.isDir).toSeq else Seq[FileStatus]()
// Load all new logs from the log directory. Only directories that have a modification time
@@ -244,6 +251,10 @@ private[history] class FsHistoryProvider(conf: SparkConf) extends ApplicationHis
}
+private object FsHistoryProvider {
+ val DEFAULT_LOG_DIR = "file:/tmp/spark-events"
+}
+
private class FsApplicationHistoryInfo(
val logDir: String,
id: String,
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
index 0e249e51a77d8..5fdc350cd8512 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryPage.scala
@@ -58,7 +58,13 @@ private[spark] class HistoryPage(parent: HistoryServer) extends WebUIPage("") {
++
appTable
} else {
-      <h4>No Completed Applications Found</h4>
+      <h4>No completed applications found!</h4> ++
+      <p>Did you specify the correct logging directory?
+        Please verify your setting of <span style="font-style:italic">
+        spark.history.fs.logDirectory</span> and whether you have the permissions to
+        access it. It is also possible that your application did not run to
+        completion or did not stop the SparkContext.
+      </p>
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
index ce00c0ffd21e0..fa9bfe5426b6c 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServer.scala
@@ -158,11 +158,12 @@ class HistoryServer(
/**
* The recommended way of starting and stopping a HistoryServer is through the scripts
- * start-history-server.sh and stop-history-server.sh. The path to a base log directory
- * is must be specified, while the requested UI port is optional. For example:
+ * start-history-server.sh and stop-history-server.sh. The path to a base log directory,
+ * as well as any other relevant history server configuration, should be specified via
+ * the $SPARK_HISTORY_OPTS environment variable. For example:
*
- * ./sbin/spark-history-server.sh /tmp/spark-events
- * ./sbin/spark-history-server.sh hdfs://1.2.3.4:9000/spark-events
+ * export SPARK_HISTORY_OPTS="-Dspark.history.fs.logDirectory=/tmp/spark-events"
+ * ./sbin/start-history-server.sh
*
* This launches the HistoryServer as a Spark daemon.
*/
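For reference, a small sketch of the properties the file-system provider now reads; the values are illustrative, and leaving spark.history.fs.logDirectory unset falls back to the new default file:/tmp/spark-events:

    import org.apache.spark.SparkConf

    object HistoryConfSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          // Directory holding completed applications' event logs.
          .set("spark.history.fs.logDirectory", "hdfs://namenode:8020/spark-events")
          // Seconds between re-scans of that directory for new logs.
          .set("spark.history.fs.updateInterval", "10")
        println(conf.get("spark.history.fs.logDirectory"))
      }
    }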
diff --git a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
index 5bce32a04d16d..b1270ade9f750 100644
--- a/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/history/HistoryServerArguments.scala
@@ -17,14 +17,13 @@
package org.apache.spark.deploy.history
-import org.apache.spark.SparkConf
+import org.apache.spark.{Logging, SparkConf}
import org.apache.spark.util.Utils
/**
* Command-line parser for the history server.
*/
-private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) {
- private var logDir: String = null
+private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]) extends Logging {
private var propertiesFile: String = null
parse(args.toList)
@@ -32,7 +31,8 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
private def parse(args: List[String]): Unit = {
args match {
case ("--dir" | "-d") :: value :: tail =>
- logDir = value
+ logWarning("Setting log directory through the command line is deprecated as of " +
+ "Spark 1.1.0. Please set this through spark.history.fs.logDirectory instead.")
conf.set("spark.history.fs.logDirectory", value)
System.setProperty("spark.history.fs.logDirectory", value)
parse(tail)
@@ -78,9 +78,10 @@ private[spark] class HistoryServerArguments(conf: SparkConf, args: Array[String]
| (default 50)
|FsHistoryProvider options:
|
- | spark.history.fs.logDirectory Directory where app logs are stored (required)
- | spark.history.fs.updateInterval How often to reload log data from storage (in seconds,
- | default 10)
+ | spark.history.fs.logDirectory Directory where app logs are stored
+ | (default: file:/tmp/spark-events)
+ | spark.history.fs.updateInterval How often to reload log data from storage
+ | (in seconds, default: 10)
|""".stripMargin)
System.exit(exitCode)
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
index 2f81d472d7b78..450c1d596d10a 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/Master.scala
@@ -129,6 +129,10 @@ private[spark] class Master(
masterMetricsSystem.registerSource(masterSource)
masterMetricsSystem.start()
applicationMetricsSystem.start()
+ // Attach the master and app metrics servlet handlers to the web UI after the metrics systems
+ // are started.
+ masterMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
+ applicationMetricsSystem.getServletHandlers.foreach(webUi.attachHandler)
persistenceEngine = RECOVERY_MODE match {
case "ZOOKEEPER" =>
@@ -636,7 +640,7 @@ private[spark] class Master(
def registerApplication(app: ApplicationInfo): Unit = {
val appAddress = app.driver.path.address
- if (addressToWorker.contains(appAddress)) {
+ if (addressToApp.contains(appAddress)) {
logInfo("Attempted to re-register application at same address: " + appAddress)
return
}
@@ -685,6 +689,11 @@ private[spark] class Master(
}
persistenceEngine.removeApplication(app)
schedule()
+
+ // Tell all workers that the application has finished, so they can clean up any app state.
+ workers.foreach { w =>
+ w.actor ! ApplicationFinished(app.id)
+ }
}
}
diff --git a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
index d86ec1e03e45c..73400c5affb5d 100644
--- a/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/master/ui/MasterWebUI.scala
@@ -41,8 +41,6 @@ class MasterWebUI(val master: Master, requestedPort: Int)
attachPage(new HistoryNotFoundPage(this))
attachPage(new MasterPage(this))
attachHandler(createStaticHandler(MasterWebUI.STATIC_RESOURCE_DIR, "/static"))
- master.masterMetricsSystem.getServletHandlers.foreach(attachHandler)
- master.applicationMetricsSystem.getServletHandlers.foreach(attachHandler)
}
/** Attach a reconstructed UI to this Master UI. Only valid after bind(). */
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
index 8ba6a01bbcb97..acbdf0d8bd7bc 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ExecutorRunner.scala
@@ -47,6 +47,7 @@ private[spark] class ExecutorRunner(
val executorDir: File,
val workerUrl: String,
val conf: SparkConf,
+ val appLocalDirs: Seq[String],
var state: ExecutorState.Value)
extends Logging {
@@ -77,7 +78,7 @@ private[spark] class ExecutorRunner(
/**
* Kill executor process, wait for exit and notify worker to update resource status.
*
- * @param message the exception message which caused the executor's death
+ * @param message the exception message which caused the executor's death
*/
private def killProcess(message: Option[String]) {
var exitCode: Option[Int] = None
@@ -129,6 +130,7 @@ private[spark] class ExecutorRunner(
logInfo("Launch command: " + command.mkString("\"", "\" \"", "\""))
builder.directory(executorDir)
+ builder.environment.put("SPARK_LOCAL_DIRS", appLocalDirs.mkString(","))
// In case we are running this from within the Spark Shell, avoid creating a "scala"
// parent process for the executor command
builder.environment.put("SPARK_LAUNCH_WITH_SCALA", "0")
@@ -144,8 +146,6 @@ private[spark] class ExecutorRunner(
Files.write(header, stderr, UTF_8)
stderrAppender = FileAppender(process.getErrorStream, stderr, conf)
- state = ExecutorState.RUNNING
- worker ! ExecutorStateChanged(appId, execId, state, None, None)
// Wait for it to exit; executor may exit with code 0 (when driver instructs it to shutdown)
// or with nonzero exit code
val exitCode = process.waitFor()
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
new file mode 100644
index 0000000000000..b9798963bab0a
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/StandaloneWorkerShuffleService.scala
@@ -0,0 +1,66 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.deploy.worker
+
+import org.apache.spark.{Logging, SparkConf, SecurityManager}
+import org.apache.spark.network.TransportContext
+import org.apache.spark.network.netty.SparkTransportConf
+import org.apache.spark.network.sasl.SaslRpcHandler
+import org.apache.spark.network.server.TransportServer
+import org.apache.spark.network.shuffle.ExternalShuffleBlockHandler
+
+/**
+ * Provides a server from which Executors can read shuffle files (rather than reading directly
+ * from each other), so that access to those files is uninterrupted even when executors are
+ * turned off or killed.
+ *
+ * Optionally requires SASL authentication in order to read. See [[SecurityManager]].
+ */
+private[worker]
+class StandaloneWorkerShuffleService(sparkConf: SparkConf, securityManager: SecurityManager)
+ extends Logging {
+
+ private val enabled = sparkConf.getBoolean("spark.shuffle.service.enabled", false)
+ private val port = sparkConf.getInt("spark.shuffle.service.port", 7337)
+ private val useSasl: Boolean = securityManager.isAuthenticationEnabled()
+
+ private val transportConf = SparkTransportConf.fromSparkConf(sparkConf, numUsableCores = 0)
+ private val blockHandler = new ExternalShuffleBlockHandler(transportConf)
+ private val transportContext: TransportContext = {
+ val handler = if (useSasl) new SaslRpcHandler(blockHandler, securityManager) else blockHandler
+ new TransportContext(transportConf, handler)
+ }
+
+ private var server: TransportServer = _
+
+ /** Starts the external shuffle service if the user has configured us to. */
+ def startIfEnabled() {
+ if (enabled) {
+ require(server == null, "Shuffle server already started")
+ logInfo(s"Starting shuffle service on port $port with useSasl = $useSasl")
+ server = transportContext.createServer(port)
+ }
+ }
+
+ def stop() {
+ if (enabled && server != null) {
+ server.close()
+ server = null
+ }
+ }
+}
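A hedged sketch of how a worker drives this service; the SparkConf values are illustrative, and because the class is private[worker] this only compiles inside that package:

    package org.apache.spark.deploy.worker

    import org.apache.spark.{SecurityManager, SparkConf}

    object ShuffleServiceSketch {
      def main(args: Array[String]): Unit = {
        val conf = new SparkConf()
          .set("spark.shuffle.service.enabled", "true") // disabled by default
          .set("spark.shuffle.service.port", "7337")
        val service = new StandaloneWorkerShuffleService(conf, new SecurityManager(conf))
        service.startIfEnabled() // no-op unless spark.shuffle.service.enabled is true
        // ... executors fetch shuffle blocks from this server instead of from each other ...
        service.stop()
      }
    }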
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
index f1f66d0903f1c..86a87ec22235e 100755
--- a/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/Worker.scala
@@ -21,10 +21,9 @@ import java.io.File
import java.io.IOException
import java.text.SimpleDateFormat
import java.util.{UUID, Date}
-import java.util.concurrent.TimeUnit
import scala.collection.JavaConversions._
-import scala.collection.mutable.HashMap
+import scala.collection.mutable.{HashMap, HashSet}
import scala.concurrent.duration._
import scala.language.postfixOps
import scala.util.Random
@@ -110,6 +109,11 @@ private[spark] class Worker(
val finishedExecutors = new HashMap[String, ExecutorRunner]
val drivers = new HashMap[String, DriverRunner]
val finishedDrivers = new HashMap[String, DriverRunner]
+ val appDirectories = new HashMap[String, Seq[String]]
+ val finishedApps = new HashSet[String]
+
+ // The shuffle service is not actually started unless configured.
+ val shuffleService = new StandaloneWorkerShuffleService(conf, securityMgr)
val publicAddress = {
val envVar = System.getenv("SPARK_PUBLIC_DNS")
@@ -154,12 +158,15 @@ private[spark] class Worker(
logInfo("Spark home: " + sparkHome)
createWorkDir()
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
+ shuffleService.startIfEnabled()
webUi = new WorkerWebUI(this, workDir, webUiPort)
webUi.bind()
registerWithMaster()
metricsSystem.registerSource(workerSource)
metricsSystem.start()
+ // Attach the worker metrics servlet handlers to the web UI after the metrics system is started.
+ metricsSystem.getServletHandlers.foreach(webUi.attachHandler)
}
def changeMaster(url: String, uiUrl: String) {
@@ -173,6 +180,9 @@ private[spark] class Worker(
throw new SparkException("Invalid spark URL: " + x)
}
connected = true
+ // Cancel any outstanding re-registration attempts because we found a new master
+ registrationRetryTimer.foreach(_.cancel())
+ registrationRetryTimer = None
}
private def tryRegisterAllMasters() {
@@ -183,7 +193,12 @@ private[spark] class Worker(
}
}
- private def retryConnectToMaster() {
+ /**
+ * Re-register with the master because a network failure or a master failure has occurred.
+ * If the re-registration attempt threshold is exceeded, the worker exits with an error.
+ * Note that for thread-safety this should only be called from the actor.
+ */
+ private def reregisterWithMaster(): Unit = {
Utils.tryOrExit {
connectionAttemptCount += 1
if (registered) {
@@ -191,12 +206,40 @@ private[spark] class Worker(
registrationRetryTimer = None
} else if (connectionAttemptCount <= TOTAL_REGISTRATION_RETRIES) {
logInfo(s"Retrying connection to master (attempt # $connectionAttemptCount)")
- tryRegisterAllMasters()
+ /**
+ * Re-register with the active master this worker has been communicating with. If there
+ * is none, then it means this worker is still bootstrapping and hasn't established a
+ * connection with a master yet, in which case we should re-register with all masters.
+ *
+ * It is important to re-register only with the active master during failures. Otherwise,
+ * if the worker unconditionally attempts to re-register with all masters, the following
+ * race condition may arise and cause a "duplicate worker" error detailed in SPARK-4592:
+ *
+ * (1) Master A fails and Worker attempts to reconnect to all masters
+ * (2) Master B takes over and notifies Worker
+ * (3) Worker responds by registering with Master B
+ * (4) Meanwhile, Worker's previous reconnection attempt reaches Master B,
+ * causing the same Worker to register with Master B twice
+ *
+ * Instead, if we only register with the known active master, we can assume that the
+ * old master must have died because another master has taken over. Note that this is
+ * still not safe if the old master recovers within this interval, but this is a much
+ * less likely scenario.
+ */
+ if (master != null) {
+ master ! RegisterWorker(
+ workerId, host, port, cores, memory, webUi.boundPort, publicAddress)
+ } else {
+ // We are retrying the initial registration
+ tryRegisterAllMasters()
+ }
+ // We have exceeded the initial registration retry threshold
+ // All retries from now on should use a higher interval
if (connectionAttemptCount == INITIAL_REGISTRATION_RETRIES) {
registrationRetryTimer.foreach(_.cancel())
registrationRetryTimer = Some {
context.system.scheduler.schedule(PROLONGED_REGISTRATION_RETRY_INTERVAL,
- PROLONGED_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
+ PROLONGED_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
}
}
} else {
@@ -216,7 +259,7 @@ private[spark] class Worker(
connectionAttemptCount = 0
registrationRetryTimer = Some {
context.system.scheduler.schedule(INITIAL_REGISTRATION_RETRY_INTERVAL,
- INITIAL_REGISTRATION_RETRY_INTERVAL)(retryConnectToMaster)
+ INITIAL_REGISTRATION_RETRY_INTERVAL, self, ReregisterWithMaster)
}
case Some(_) =>
logInfo("Not spawning another attempt to register with the master, since there is an" +
@@ -253,7 +296,7 @@ private[spark] class Worker(
val isAppStillRunning = executors.values.map(_.appId).contains(appIdFromDir)
dir.isDirectory && !isAppStillRunning &&
!Utils.doesDirectoryContainAnyNewFiles(dir, APP_DATA_RETENTION_SECS)
- }.foreach { dir =>
+ }.foreach { dir =>
logInfo(s"Removing directory: ${dir.getPath}")
Utils.deleteRecursively(dir)
}
@@ -298,8 +341,19 @@ private[spark] class Worker(
throw new IOException("Failed to create directory " + executorDir)
}
+ // Create local dirs for the executor. These are passed to the executor via the
+ // SPARK_LOCAL_DIRS environment variable, and deleted by the Worker when the
+ // application finishes.
+ val appLocalDirs = appDirectories.get(appId).getOrElse {
+ Utils.getOrCreateLocalRootDirs(conf).map { dir =>
+ Utils.createDirectory(dir).getAbsolutePath()
+ }.toSeq
+ }
+ appDirectories(appId) = appLocalDirs
+
val manager = new ExecutorRunner(appId, execId, appDesc, cores_, memory_,
- self, workerId, host, sparkHome, executorDir, akkaUrl, conf, ExecutorState.LOADING)
+ self, workerId, host, sparkHome, executorDir, akkaUrl, conf, appLocalDirs,
+ ExecutorState.LOADING)
executors(appId + "/" + execId) = manager
manager.start()
coresUsed += cores_
@@ -336,6 +390,7 @@ private[spark] class Worker(
message.map(" message " + _).getOrElse("") +
exitStatus.map(" exitStatus " + _).getOrElse(""))
}
+ maybeCleanupApplication(appId)
}
case KillExecutor(masterUrl, appId, execId) =>
@@ -396,12 +451,18 @@ private[spark] class Worker(
logInfo(s"$x Disassociated !")
masterDisconnected()
- case RequestWorkerState => {
+ case RequestWorkerState =>
sender ! WorkerStateResponse(host, port, workerId, executors.values.toList,
finishedExecutors.values.toList, drivers.values.toList,
finishedDrivers.values.toList, activeMasterUrl, cores, memory,
coresUsed, memoryUsed, activeMasterWebUiUrl)
- }
+
+ case ReregisterWithMaster =>
+ reregisterWithMaster()
+
+ case ApplicationFinished(id) =>
+ finishedApps += id
+ maybeCleanupApplication(id)
}
private def masterDisconnected() {
@@ -410,6 +471,19 @@ private[spark] class Worker(
registerWithMaster()
}
+ private def maybeCleanupApplication(id: String): Unit = {
+ val shouldCleanup = finishedApps.contains(id) && !executors.values.exists(_.appId == id)
+ if (shouldCleanup) {
+ finishedApps -= id
+ appDirectories.remove(id).foreach { dirList =>
+ logInfo(s"Cleaning up local directories for application $id")
+ dirList.foreach { dir =>
+ Utils.deleteRecursively(new File(dir))
+ }
+ }
+ }
+ }
+
def generateWorkerId(): String = {
"worker-%s-%s-%d".format(createDateFormat.format(new Date), host, port)
}
@@ -419,6 +493,7 @@ private[spark] class Worker(
registrationRetryTimer.foreach(_.cancel())
executors.values.foreach(_.kill())
drivers.values.foreach(_.kill())
+ shuffleService.stop()
webUi.stop()
metricsSystem.stop()
}
@@ -441,7 +516,8 @@ private[spark] object Worker extends Logging {
cores: Int,
memory: Int,
masterUrls: Array[String],
- workDir: String, workerNumber: Option[Int] = None): (ActorSystem, Int) = {
+ workDir: String,
+ workerNumber: Option[Int] = None): (ActorSystem, Int) = {
// The LocalSparkCluster runs multiple local sparkWorkerX actor systems
val conf = new SparkConf
diff --git a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
index b07942a9ca729..7ac81a2d87efd 100644
--- a/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/worker/ui/WorkerWebUI.scala
@@ -50,7 +50,6 @@ class WorkerWebUI(
attachHandler(createStaticHandler(WorkerWebUI.STATIC_RESOURCE_BASE, "/static"))
attachHandler(createServletHandler("/log",
(request: HttpServletRequest) => logPage.renderLog(request), worker.securityMgr))
- worker.metricsSystem.getServletHandlers.foreach(attachHandler)
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
index 697154d762d41..5f46f3b1f085e 100644
--- a/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/CoarseGrainedExecutorBackend.scala
@@ -57,9 +57,9 @@ private[spark] class CoarseGrainedExecutorBackend(
override def receiveWithLogging = {
case RegisteredExecutor =>
logInfo("Successfully registered with driver")
- // Make this host instead of hostPort ?
val (hostname, _) = Utils.parseHostPort(hostPort)
- executor = new Executor(executorId, hostname, sparkProperties, isLocal = false, actorSystem)
+ executor = new Executor(executorId, hostname, sparkProperties, cores, isLocal = false,
+ actorSystem)
case RegisterExecutorFailed(message) =>
logError("Slave registration failed: " + message)
@@ -131,7 +131,8 @@ private[spark] object CoarseGrainedExecutorBackend extends Logging {
// Create a new ActorSystem using driver's Spark properties to run the backend.
val driverConf = new SparkConf().setAll(props)
val (actorSystem, boundPort) = AkkaUtils.createActorSystem(
- "sparkExecutor", hostname, port, driverConf, new SecurityManager(driverConf))
+ SparkEnv.executorActorSystemName,
+ hostname, port, driverConf, new SecurityManager(driverConf))
// set it
val sparkHostPort = hostname + ":" + boundPort
actorSystem.actorOf(
diff --git a/core/src/main/scala/org/apache/spark/executor/Executor.scala b/core/src/main/scala/org/apache/spark/executor/Executor.scala
index e24a15f015e1c..eaf0c82d52996 100644
--- a/core/src/main/scala/org/apache/spark/executor/Executor.scala
+++ b/core/src/main/scala/org/apache/spark/executor/Executor.scala
@@ -26,7 +26,7 @@ import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
-import akka.actor.ActorSystem
+import akka.actor.{Props, ActorSystem}
import org.apache.spark._
import org.apache.spark.deploy.SparkHadoopUtil
@@ -41,12 +41,16 @@ import org.apache.spark.util.{SparkUncaughtExceptionHandler, AkkaUtils, Utils}
*/
private[spark] class Executor(
executorId: String,
- slaveHostname: String,
+ executorHostname: String,
properties: Seq[(String, String)],
+ numCores: Int,
isLocal: Boolean = false,
actorSystem: ActorSystem = null)
extends Logging
{
+
+ logInfo(s"Starting executor ID $executorId on host $executorHostname")
+
// Application dependencies (added through SparkContext) that we've fetched so far on this node.
// Each map holds the master's timestamp for the version of that file or JAR we got.
private val currentFiles: HashMap[String, Long] = new HashMap[String, Long]()
@@ -57,12 +61,12 @@ private[spark] class Executor(
@volatile private var isStopped = false
// No ip or host:port - just hostname
- Utils.checkHost(slaveHostname, "Expected executed slave to be a hostname")
+ Utils.checkHost(executorHostname, "Expected executed slave to be a hostname")
// must not have port specified.
- assert (0 == Utils.parseHostPort(slaveHostname)._2)
+ assert (0 == Utils.parseHostPort(executorHostname)._2)
// Make sure the local hostname we report matches the cluster scheduler's name for this host
- Utils.setCustomHostname(slaveHostname)
+ Utils.setCustomHostname(executorHostname)
// Set spark.* properties from executor arg
val conf = new SparkConf(true)
@@ -83,15 +87,20 @@ private[spark] class Executor(
if (!isLocal) {
val port = conf.getInt("spark.executor.port", 0)
val _env = SparkEnv.createExecutorEnv(
- conf, executorId, slaveHostname, port, isLocal, actorSystem)
+ conf, executorId, executorHostname, port, numCores, isLocal, actorSystem)
SparkEnv.set(_env)
_env.metricsSystem.registerSource(executorSource)
+ _env.blockManager.initialize(conf.getAppId)
_env
} else {
SparkEnv.get
}
}
+ // Create an actor for receiving RPCs from the driver
+ private val executorActor = env.actorSystem.actorOf(
+ Props(new ExecutorActor(executorId)), "ExecutorActor")
+
// Create our ClassLoader
// do this after SparkEnv creation so can access the SecurityManager
private val urlClassLoader = createClassLoader()
@@ -131,6 +140,7 @@ private[spark] class Executor(
def stop() {
env.metricsSystem.report()
+ env.actorSystem.stop(executorActor)
isStopped = true
threadPool.shutdown()
if (!isLocal) {
@@ -138,6 +148,8 @@ private[spark] class Executor(
}
}
+ private def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
+
class TaskRunner(
execBackend: ExecutorBackend, val taskId: Long, taskName: String, serializedTask: ByteBuffer)
extends Runnable {
@@ -145,6 +157,7 @@ private[spark] class Executor(
@volatile private var killed = false
@volatile var task: Task[Any] = _
@volatile var attemptedTask: Option[Task[Any]] = None
+ @volatile var startGCTime: Long = _
def kill(interruptThread: Boolean) {
logInfo(s"Executor is trying to kill $taskName (TID $taskId)")
@@ -155,17 +168,15 @@ private[spark] class Executor(
}
override def run() {
- val startTime = System.currentTimeMillis()
+ val deserializeStartTime = System.currentTimeMillis()
Thread.currentThread.setContextClassLoader(replClassLoader)
val ser = SparkEnv.get.closureSerializer.newInstance()
logInfo(s"Running $taskName (TID $taskId)")
execBackend.statusUpdate(taskId, TaskState.RUNNING, EMPTY_BYTE_BUFFER)
var taskStart: Long = 0
- def gcTime = ManagementFactory.getGarbageCollectorMXBeans.map(_.getCollectionTime).sum
- val startGCTime = gcTime
+ startGCTime = gcTime
try {
- Accumulators.clear()
val (taskFiles, taskJars, taskBytes) = Task.deserializeWithDependencies(serializedTask)
updateDependencies(taskFiles, taskJars)
task = ser.deserialize[Task[Any]](taskBytes, Thread.currentThread.getContextClassLoader)
@@ -200,7 +211,7 @@ private[spark] class Executor(
val afterSerialization = System.currentTimeMillis()
for (m <- task.metrics) {
- m.executorDeserializeTime = taskStart - startTime
+ m.executorDeserializeTime = taskStart - deserializeStartTime
m.executorRunTime = taskFinish - taskStart
m.jvmGCTime = gcTime - startGCTime
m.resultSerializationTime = afterSerialization - beforeSerialization
@@ -214,7 +225,7 @@ private[spark] class Executor(
// directSend = sending directly back to the driver
val serializedResult = {
- if (resultSize > maxResultSize) {
+ if (maxResultSize > 0 && resultSize > maxResultSize) {
logWarning(s"Finished $taskName (TID $taskId). Result is larger than maxResultSize " +
s"(${Utils.bytesToString(resultSize)} > ${Utils.bytesToString(maxResultSize)}), " +
s"dropping it.")
@@ -257,7 +268,7 @@ private[spark] class Executor(
m.executorRunTime = serviceTime
m.jvmGCTime = gcTime - startGCTime
}
- val reason = ExceptionFailure(t.getClass.getName, t.getMessage, t.getStackTrace, metrics)
+ val reason = new ExceptionFailure(t, metrics)
execBackend.statusUpdate(taskId, TaskState.FAILED, ser.serialize(reason))
// Don't forcibly exit unless the exception was inherently fatal, to avoid
@@ -271,6 +282,8 @@ private[spark] class Executor(
env.shuffleMemoryManager.releaseMemoryForThisThread()
// Release memory used by this thread for unrolling blocks
env.blockManager.memoryStore.releaseUnrollMemoryForThisThread()
+ // Release memory used by this thread for accumulators
+ Accumulators.clear()
runningTasks.remove(taskId)
}
}
@@ -327,7 +340,7 @@ private[spark] class Executor(
* SparkContext. Also adds any new JARs we fetched to the class loader.
*/
private def updateDependencies(newFiles: HashMap[String, Long], newJars: HashMap[String, Long]) {
- val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
+ lazy val hadoopConf = SparkHadoopUtil.get.newConfiguration(conf)
synchronized {
// Fetch missing dependencies
for ((name, timestamp) <- newFiles if currentFiles.getOrElse(name, -1L) < timestamp) {
@@ -368,10 +381,13 @@ private[spark] class Executor(
while (!isStopped) {
val tasksMetrics = new ArrayBuffer[(Long, TaskMetrics)]()
+ val curGCTime = gcTime
+
for (taskRunner <- runningTasks.values()) {
if (!taskRunner.attemptedTask.isEmpty) {
Option(taskRunner.task).flatMap(_.metrics).foreach { metrics =>
metrics.updateShuffleReadMetrics
+ metrics.jvmGCTime = curGCTime - taskRunner.startGCTime
if (isLocal) {
// JobProgressListener will hold a reference to it during
// onExecutorMetricsUpdate(), then JobProgressListener can not see
diff --git a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
similarity index 59%
rename from graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
rename to core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
index 49b2704390fea..41925f7e97e84 100644
--- a/graphx/src/test/scala/org/apache/spark/graphx/impl/EdgeTripletIteratorSuite.scala
+++ b/core/src/main/scala/org/apache/spark/executor/ExecutorActor.scala
@@ -15,23 +15,27 @@
* limitations under the License.
*/
-package org.apache.spark.graphx.impl
+package org.apache.spark.executor
-import scala.reflect.ClassTag
-import scala.util.Random
+import akka.actor.Actor
+import org.apache.spark.Logging
-import org.scalatest.FunSuite
+import org.apache.spark.util.{Utils, ActorLogReceive}
-import org.apache.spark.graphx._
+/**
+ * Driver -> Executor message to trigger a thread dump.
+ */
+private[spark] case object TriggerThreadDump
+
+/**
+ * Actor that runs inside of executors to enable driver -> executor RPC.
+ */
+private[spark]
+class ExecutorActor(executorId: String) extends Actor with ActorLogReceive with Logging {
-class EdgeTripletIteratorSuite extends FunSuite {
- test("iterator.toList") {
- val builder = new EdgePartitionBuilder[Int, Int]
- builder.add(1, 2, 0)
- builder.add(1, 3, 0)
- builder.add(1, 4, 0)
- val iter = new EdgeTripletIterator[Int, Int](builder.toEdgePartition, true, true)
- val result = iter.toList.map(et => (et.srcId, et.dstId))
- assert(result === Seq((1, 2), (1, 3), (1, 4)))
+ override def receiveWithLogging = {
+ case TriggerThreadDump =>
+ sender ! Utils.getThreadDump()
}
+
}
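A rough sketch of the driver-side call pattern this enables; the actor path and timeout are placeholders, and the real lookup in Spark goes through its own Akka helpers:

    package org.apache.spark // TriggerThreadDump is private[spark]

    import scala.concurrent.Await
    import scala.concurrent.duration._

    import akka.actor.ActorSystem
    import akka.pattern.ask
    import akka.util.Timeout

    import org.apache.spark.executor.TriggerThreadDump

    object ThreadDumpSketch {
      // Ask a remote ExecutorActor for a dump of its JVM threads and wait for the reply.
      def requestThreadDump(system: ActorSystem, executorActorPath: String): Any = {
        implicit val timeout = Timeout(30.seconds)
        val executorActor = system.actorSelection(executorActorPath)
        Await.result(executorActor ? TriggerThreadDump, timeout.duration)
      }
    }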
diff --git a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
index bca0b152268ad..1c6ac0525428f 100644
--- a/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
+++ b/core/src/main/scala/org/apache/spark/executor/MesosExecutorBackend.scala
@@ -19,8 +19,10 @@ package org.apache.spark.executor
import java.nio.ByteBuffer
+import scala.collection.JavaConversions._
+
import org.apache.mesos.protobuf.ByteString
-import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver, MesosNativeLibrary}
+import org.apache.mesos.{Executor => MesosExecutor, ExecutorDriver, MesosExecutorDriver}
import org.apache.mesos.Protos.{TaskStatus => MesosTaskStatus, _}
import org.apache.spark.{Logging, TaskState}
@@ -50,14 +52,23 @@ private[spark] class MesosExecutorBackend
executorInfo: ExecutorInfo,
frameworkInfo: FrameworkInfo,
slaveInfo: SlaveInfo) {
- logInfo("Registered with Mesos as executor ID " + executorInfo.getExecutorId.getValue)
+
+ // Get num cores for this task from ExecutorInfo, created in MesosSchedulerBackend.
+ val cpusPerTask = executorInfo.getResourcesList
+ .find(_.getName == "cpus")
+ .map(_.getScalar.getValue.toInt)
+ .getOrElse(0)
+ val executorId = executorInfo.getExecutorId.getValue
+
+ logInfo(s"Registered with Mesos as executor ID $executorId with $cpusPerTask cpus")
this.driver = driver
val properties = Utils.deserialize[Array[(String, String)]](executorInfo.getData.toByteArray) ++
Seq[(String, String)](("spark.app.id", frameworkInfo.getId.getValue))
executor = new Executor(
- executorInfo.getExecutorId.getValue,
+ executorId,
slaveInfo.getHostname,
- properties)
+ properties,
+ cpusPerTask)
}
override def launchTask(d: ExecutorDriver, taskInfo: TaskInfo) {
@@ -65,7 +76,9 @@ private[spark] class MesosExecutorBackend
if (executor == null) {
logError("Received launchTask but executor was null")
} else {
- executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
+ SparkHadoopUtil.get.runAsSparkUser { () =>
+ executor.launchTask(this, taskId, taskInfo.getName, taskInfo.getData.asReadOnlyByteBuffer)
+ }
}
}
@@ -97,11 +110,8 @@ private[spark] class MesosExecutorBackend
private[spark] object MesosExecutorBackend extends Logging {
def main(args: Array[String]) {
SignalLogger.register(log)
- SparkHadoopUtil.get.runAsSparkUser { () =>
- MesosNativeLibrary.load()
- // Create a new Executor and start it running
- val runner = new MesosExecutorBackend()
- new MesosExecutorDriver(runner).run()
- }
+ // Create a new Executor and start it running
+ val runner = new MesosExecutorBackend()
+ new MesosExecutorDriver(runner).run()
}
}
diff --git a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
index 57bc2b40cec44..51b5328cb4c8f 100644
--- a/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
+++ b/core/src/main/scala/org/apache/spark/executor/TaskMetrics.scala
@@ -82,6 +82,12 @@ class TaskMetrics extends Serializable {
*/
var inputMetrics: Option[InputMetrics] = None
+ /**
+ * If this task writes data externally (e.g. to a distributed filesystem), metrics on how much
+ * data was written are stored here.
+ */
+ var outputMetrics: Option[OutputMetrics] = None
+
/**
* If this task reads from shuffle output, metrics on getting shuffle data will be collected here.
* This includes read metrics aggregated over all the task's shuffle dependencies.
@@ -157,6 +163,16 @@ object DataReadMethod extends Enumeration with Serializable {
val Memory, Disk, Hadoop, Network = Value
}
+/**
+ * :: DeveloperApi ::
+ * Method by which output data was written.
+ */
+@DeveloperApi
+object DataWriteMethod extends Enumeration with Serializable {
+ type DataWriteMethod = Value
+ val Hadoop = Value
+}
+
/**
* :: DeveloperApi ::
* Metrics about reading input data.
@@ -169,6 +185,18 @@ case class InputMetrics(readMethod: DataReadMethod.Value) {
var bytesRead: Long = 0L
}
+/**
+ * :: DeveloperApi ::
+ * Metrics about writing output data.
+ */
+@DeveloperApi
+case class OutputMetrics(writeMethod: DataWriteMethod.Value) {
+ /**
+ * Total bytes written
+ */
+ var bytesWritten: Long = 0L
+}
+
/**
* :: DeveloperApi ::
* Metrics pertaining to shuffle data read in a given task.
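A minimal sketch of how a writer could fill in the new output metrics; the helper is hypothetical, and the real wiring to Hadoop output paths is done elsewhere in this patch:

    import org.apache.spark.executor.{DataWriteMethod, OutputMetrics, TaskMetrics}

    object OutputMetricsSketch {
      def recordWrite(metrics: TaskMetrics, bytes: Long): Unit = {
        // Create the output metrics on first use, then accumulate bytes written to Hadoop.
        val out = metrics.outputMetrics.getOrElse {
          val m = OutputMetrics(DataWriteMethod.Hadoop)
          metrics.outputMetrics = Some(m)
          m
        }
        out.bytesWritten += bytes
      }
    }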
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
index 89b29af2000c8..c219d21fbefa9 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryInputFormat.scala
@@ -21,6 +21,7 @@ import org.apache.hadoop.fs.Path
import org.apache.hadoop.io.{BytesWritable, LongWritable}
import org.apache.hadoop.mapreduce.lib.input.FileInputFormat
import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAttemptContext}
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* Custom Input Format for reading and splitting flat binary files that contain records,
@@ -33,7 +34,7 @@ private[spark] object FixedLengthBinaryInputFormat {
/** Retrieves the record length property from a Hadoop configuration */
def getRecordLength(context: JobContext): Int = {
- context.getConfiguration.get(RECORD_LENGTH_PROPERTY).toInt
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context).get(RECORD_LENGTH_PROPERTY).toInt
}
}
diff --git a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
index 5164a74bec4e9..67a96925da019 100644
--- a/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/FixedLengthBinaryRecordReader.scala
@@ -24,6 +24,7 @@ import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.io.{BytesWritable, LongWritable}
import org.apache.hadoop.mapreduce.{InputSplit, RecordReader, TaskAttemptContext}
import org.apache.hadoop.mapreduce.lib.input.FileSplit
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* FixedLengthBinaryRecordReader is returned by FixedLengthBinaryInputFormat.
@@ -82,7 +83,7 @@ private[spark] class FixedLengthBinaryRecordReader
// the actual file we will be reading from
val file = fileSplit.getPath
// job configuration
- val job = context.getConfiguration
+ val job = SparkHadoopUtil.get.getConfigurationFromJobContext(context)
// check compression
val codec = new CompressionCodecFactory(job).getCodec(file)
if (codec != null) {
@@ -115,7 +116,7 @@ private[spark] class FixedLengthBinaryRecordReader
if (currentPosition < splitEnd) {
// setup a buffer to store the record
val buffer = recordValue.getBytes
- fileInputStream.read(buffer, 0, recordLength)
+ fileInputStream.readFully(buffer)
// update our current position
currentPosition = currentPosition + recordLength
// return true
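Why readFully matters, shown on a plain DataInputStream (illustrative data only): read may return fewer bytes than requested, while readFully blocks until the whole fixed-length record is in the buffer or throws EOFException.

    import java.io.{ByteArrayInputStream, DataInputStream}

    object ReadFullySketch {
      def main(args: Array[String]): Unit = {
        val record = Array.fill[Byte](10)(1)
        val in = new DataInputStream(new ByteArrayInputStream(record))
        val buffer = new Array[Byte](10)
        // read(buffer, 0, 10) could legally return after fewer than 10 bytes;
        // readFully only returns once the buffer holds a complete record.
        in.readFully(buffer)
        println(buffer.sum)
      }
    }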
diff --git a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
index 457472547fcbb..593a62b3e3b32 100644
--- a/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
+++ b/core/src/main/scala/org/apache/spark/input/PortableDataStream.scala
@@ -28,6 +28,7 @@ import org.apache.hadoop.mapreduce.{InputSplit, JobContext, RecordReader, TaskAt
import org.apache.hadoop.mapreduce.lib.input.{CombineFileInputFormat, CombineFileRecordReader, CombineFileSplit}
import org.apache.spark.annotation.Experimental
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* A general format for reading whole files in as streams, byte arrays,
@@ -145,7 +146,8 @@ class PortableDataStream(
private val confBytes = {
val baos = new ByteArrayOutputStream()
- context.getConfiguration.write(new DataOutputStream(baos))
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context).
+ write(new DataOutputStream(baos))
baos.toByteArray
}
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
index 183bce3d8d8d3..d3601cca832b2 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileInputFormat.scala
@@ -19,14 +19,13 @@ package org.apache.spark.input
import scala.collection.JavaConversions._
+import org.apache.hadoop.conf.{Configuration, Configurable}
import org.apache.hadoop.fs.Path
import org.apache.hadoop.mapreduce.InputSplit
import org.apache.hadoop.mapreduce.JobContext
import org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
-import org.apache.hadoop.mapreduce.lib.input.CombineFileRecordReader
-import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
/**
* A [[org.apache.hadoop.mapreduce.lib.input.CombineFileInputFormat CombineFileInputFormat]] for
@@ -34,17 +33,24 @@ import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
* the value is the entire content of file.
*/
-private[spark] class WholeTextFileInputFormat extends CombineFileInputFormat[String, String] {
+private[spark] class WholeTextFileInputFormat
+ extends CombineFileInputFormat[String, String] with Configurable {
+
override protected def isSplitable(context: JobContext, file: Path): Boolean = false
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
+
override def createRecordReader(
split: InputSplit,
context: TaskAttemptContext): RecordReader[String, String] = {
- new CombineFileRecordReader[String, String](
- split.asInstanceOf[CombineFileSplit],
- context,
- classOf[WholeTextFileRecordReader])
+ val reader = new WholeCombineFileRecordReader(split, context)
+ reader.setConf(conf)
+ reader
}
/**
diff --git a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
index 3564ab2e2a162..4fa84b69aabbc 100644
--- a/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
+++ b/core/src/main/scala/org/apache/spark/input/WholeTextFileRecordReader.scala
@@ -17,13 +17,16 @@
package org.apache.spark.input
+import org.apache.hadoop.conf.{Configuration, Configurable}
import com.google.common.io.{ByteStreams, Closeables}
import org.apache.hadoop.io.Text
+import org.apache.hadoop.io.compress.CompressionCodecFactory
import org.apache.hadoop.mapreduce.InputSplit
-import org.apache.hadoop.mapreduce.lib.input.CombineFileSplit
+import org.apache.hadoop.mapreduce.lib.input.{CombineFileSplit, CombineFileRecordReader}
import org.apache.hadoop.mapreduce.RecordReader
import org.apache.hadoop.mapreduce.TaskAttemptContext
+import org.apache.spark.deploy.SparkHadoopUtil
/**
* A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
@@ -34,10 +37,17 @@ private[spark] class WholeTextFileRecordReader(
split: CombineFileSplit,
context: TaskAttemptContext,
index: Integer)
- extends RecordReader[String, String] {
+ extends RecordReader[String, String] with Configurable {
+
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
private[this] val path = split.getPath(index)
- private[this] val fs = path.getFileSystem(context.getConfiguration)
+ private[this] val fs = path.getFileSystem(
+ SparkHadoopUtil.get.getConfigurationFromJobContext(context))
// True means the current file has been processed, then skip it.
private[this] var processed = false
@@ -57,8 +67,16 @@ private[spark] class WholeTextFileRecordReader(
override def nextKeyValue(): Boolean = {
if (!processed) {
+ val conf = new Configuration
+ val factory = new CompressionCodecFactory(conf)
+ val codec = factory.getCodec(path) // infers from file ext.
val fileIn = fs.open(path)
- val innerBuffer = ByteStreams.toByteArray(fileIn)
+ val innerBuffer = if (codec != null) {
+ ByteStreams.toByteArray(codec.createInputStream(fileIn))
+ } else {
+ ByteStreams.toByteArray(fileIn)
+ }
+
value = new Text(innerBuffer).toString
Closeables.close(fileIn, false)
processed = true
@@ -68,3 +86,33 @@ private[spark] class WholeTextFileRecordReader(
}
}
}
+
+
+/**
+ * A [[org.apache.hadoop.mapreduce.RecordReader RecordReader]] for reading a single whole text file
+ * out in a key-value pair, where the key is the file path and the value is the entire content of
+ * the file.
+ */
+private[spark] class WholeCombineFileRecordReader(
+ split: InputSplit,
+ context: TaskAttemptContext)
+ extends CombineFileRecordReader[String, String](
+ split.asInstanceOf[CombineFileSplit],
+ context,
+ classOf[WholeTextFileRecordReader]
+ ) with Configurable {
+
+ private var conf: Configuration = _
+ def setConf(c: Configuration) {
+ conf = c
+ }
+ def getConf: Configuration = conf
+
+ override def initNextRecordReader(): Boolean = {
+ val r = super.initNextRecordReader()
+ if (r) {
+ this.curReader.asInstanceOf[WholeTextFileRecordReader].setConf(conf)
+ }
+ r
+ }
+}
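Taken together with the input-format change above, whole-file reads still go through SparkContext.wholeTextFiles; a hedged usage sketch (paths and master are placeholders), where files with a recognized compression extension are now decompressed transparently:

    import org.apache.spark.{SparkConf, SparkContext}
    import org.apache.spark.SparkContext._

    object WholeTextFilesSketch {
      def main(args: Array[String]): Unit = {
        val sc = new SparkContext(new SparkConf().setAppName("whole-text-files").setMaster("local[2]"))
        // Each record is (file path, full file contents); .gz files are decoded via the codec factory.
        val sizes = sc.wholeTextFiles("file:///tmp/logs").mapValues(_.length)
        sizes.collect().foreach { case (path, length) => println(s"$path -> $length chars") }
        sc.stop()
      }
    }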
diff --git a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
similarity index 79%
rename from core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
rename to core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
index 0c47afae54c8b..21b782edd2a9e 100644
--- a/core/src/main/scala/org/apache/hadoop/mapred/SparkHadoopMapRedUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapred/SparkHadoopMapRedUtil.scala
@@ -15,15 +15,24 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapred
+package org.apache.spark.mapred
-private[apache]
+import java.lang.reflect.Modifier
+
+import org.apache.hadoop.mapred.{TaskAttemptID, JobID, JobConf, JobContext, TaskAttemptContext}
+
+private[spark]
trait SparkHadoopMapRedUtil {
def newJobContext(conf: JobConf, jobId: JobID): JobContext = {
val klass = firstAvailableClass("org.apache.hadoop.mapred.JobContextImpl",
"org.apache.hadoop.mapred.JobContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf],
classOf[org.apache.hadoop.mapreduce.JobID])
+ // In Hadoop 1.0.x, JobContext is an interface, and JobContextImpl is package private.
+ // Make the constructor accessible if it is not, so that we can instantiate it.
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, jobId).asInstanceOf[JobContext]
}
@@ -31,6 +40,10 @@ trait SparkHadoopMapRedUtil {
val klass = firstAvailableClass("org.apache.hadoop.mapred.TaskAttemptContextImpl",
"org.apache.hadoop.mapred.TaskAttemptContext")
val ctor = klass.getDeclaredConstructor(classOf[JobConf], classOf[TaskAttemptID])
+ // See above
+ if (!Modifier.isPublic(ctor.getModifiers)) {
+ ctor.setAccessible(true)
+ }
ctor.newInstance(conf, attemptId).asInstanceOf[TaskAttemptContext]
}
diff --git a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
similarity index 96%
rename from core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
rename to core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
index 1fca5729c6092..3340673f91156 100644
--- a/core/src/main/scala/org/apache/hadoop/mapreduce/SparkHadoopMapReduceUtil.scala
+++ b/core/src/main/scala/org/apache/spark/mapreduce/SparkHadoopMapReduceUtil.scala
@@ -15,13 +15,14 @@
* limitations under the License.
*/
-package org.apache.hadoop.mapreduce
+package org.apache.spark.mapreduce
import java.lang.{Boolean => JBoolean, Integer => JInteger}
import org.apache.hadoop.conf.Configuration
+import org.apache.hadoop.mapreduce.{JobContext, JobID, TaskAttemptContext, TaskAttemptID}
-private[apache]
+private[spark]
trait SparkHadoopMapReduceUtil {
def newJobContext(conf: Configuration, jobId: JobID): JobContext = {
val klass = firstAvailableClass(
diff --git a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
index 5dd67b0cbf683..83e8eb71260eb 100644
--- a/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
+++ b/core/src/main/scala/org/apache/spark/metrics/MetricsSystem.scala
@@ -76,22 +76,36 @@ private[spark] class MetricsSystem private (
private val sources = new mutable.ArrayBuffer[Source]
private val registry = new MetricRegistry()
+ private var running: Boolean = false
+
// Treat MetricsServlet as a special sink as it should be exposed to add handlers to web ui
private var metricsServlet: Option[MetricsServlet] = None
- /** Get any UI handlers used by this metrics system. */
- def getServletHandlers = metricsServlet.map(_.getHandlers).getOrElse(Array())
+ /**
+ * Get any UI handlers used by this metrics system; can only be called after start().
+ */
+ def getServletHandlers = {
+ require(running, "Can only call getServletHandlers on a running MetricsSystem")
+ metricsServlet.map(_.getHandlers).getOrElse(Array())
+ }
metricsConfig.initialize()
def start() {
+ require(!running, "Attempting to start a MetricsSystem that is already running")
+ running = true
registerSources()
registerSinks()
sinks.foreach(_.start)
}
def stop() {
- sinks.foreach(_.stop)
+ if (running) {
+ sinks.foreach(_.stop)
+ } else {
+ logWarning("Stopping a MetricsSystem that is not running")
+ }
+ running = false
}
def report() {
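A hedged sketch of the start-before-handlers ordering this enforces; createMetricsSystem and WebUI are private[spark], so this is illustrative only and mirrors what Master and Worker now do:

    package org.apache.spark.metrics // needed to reach the private[spark] factory

    import org.apache.spark.{SecurityManager, SparkConf}
    import org.apache.spark.ui.WebUI

    object MetricsWiringSketch {
      def wire(conf: SparkConf, securityMgr: SecurityManager, webUi: WebUI): MetricsSystem = {
        val metricsSystem = MetricsSystem.createMetricsSystem("master", conf, securityMgr)
        metricsSystem.start()                                          // must happen first now
        metricsSystem.getServletHandlers.foreach(webUi.attachHandler)  // requires a running system
        metricsSystem
      }
    }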
@@ -107,7 +121,7 @@ private[spark] class MetricsSystem private (
* @return An unique metric name for each combination of
* application, executor/driver and metric source.
*/
- def buildRegistryName(source: Source): String = {
+ private[spark] def buildRegistryName(source: Source): String = {
val appId = conf.getOption("spark.app.id")
val executorId = conf.getOption("spark.executor.id")
val defaultName = MetricRegistry.name(source.sourceName)
@@ -116,8 +130,8 @@ private[spark] class MetricsSystem private (
if (appId.isDefined && executorId.isDefined) {
MetricRegistry.name(appId.get, executorId.get, source.sourceName)
} else {
- // Only Driver and Executor are set spark.app.id and spark.executor.id.
- // For instance, Master and Worker are not related to a specific application.
+ // Only Driver and Executor set spark.app.id and spark.executor.id.
+ // Other instance types, e.g. Master and Worker, are not related to a specific application.
val warningMsg = s"Using default name $defaultName for source because %s is not set."
if (appId.isEmpty) { logWarning(warningMsg.format("spark.app.id")) }
if (executorId.isEmpty) { logWarning(warningMsg.format("spark.executor.id")) }
@@ -144,7 +158,7 @@ private[spark] class MetricsSystem private (
})
}
- def registerSources() {
+ private def registerSources() {
val instConfig = metricsConfig.getInstance(instance)
val sourceConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SOURCE_REGEX)
@@ -160,7 +174,7 @@ private[spark] class MetricsSystem private (
}
}
- def registerSinks() {
+ private def registerSinks() {
val instConfig = metricsConfig.getInstance(instance)
val sinkConfigs = metricsConfig.subProperties(instConfig, MetricsSystem.SINK_REGEX)
diff --git a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
index 210a581db466e..dcbda5a8515dd 100644
--- a/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/BlockTransferService.scala
@@ -73,6 +73,7 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
def uploadBlock(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit]
@@ -110,9 +111,10 @@ abstract class BlockTransferService extends ShuffleClient with Closeable with Lo
def uploadBlockSync(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Unit = {
- Await.result(uploadBlock(hostname, port, blockId, blockData, level), Duration.Inf)
+ Await.result(uploadBlock(hostname, port, execId, blockId, blockData, level), Duration.Inf)
}
}
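uploadBlockSync is simply the asynchronous uploadBlock blocked on with Await.result and an unbounded duration. A hedged sketch of the same sync-over-async wrapper, with uploadAsync as a hypothetical stand-in for the real transfer call:

    import scala.concurrent.{Await, Future}
    import scala.concurrent.duration.Duration
    import scala.concurrent.ExecutionContext.Implicits.global

    object SyncWrapperSketch {
      // Hypothetical async operation standing in for uploadBlock(...).
      def uploadAsync(blockId: String, data: Array[Byte]): Future[Unit] =
        Future { /* send bytes over the wire */ () }

      // Synchronous variant, mirroring uploadBlockSync: block until the future completes.
      def uploadSync(blockId: String, data: Array[Byte]): Unit =
        Await.result(uploadAsync(blockId, data), Duration.Inf)

      def main(args: Array[String]): Unit = {
        uploadSync("rdd_0_0", Array[Byte](1, 2, 3))
        println("upload finished")
      }
    }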
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
index 1950e7bd634ee..b089da8596e2b 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockRpcServer.scala
@@ -26,18 +26,10 @@ import org.apache.spark.network.BlockDataManager
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.client.{RpcResponseCallback, TransportClient}
import org.apache.spark.network.server.{OneForOneStreamManager, RpcHandler, StreamManager}
-import org.apache.spark.network.shuffle.ShuffleStreamHandle
+import org.apache.spark.network.shuffle.protocol.{BlockTransferMessage, OpenBlocks, StreamHandle, UploadBlock}
import org.apache.spark.serializer.Serializer
import org.apache.spark.storage.{BlockId, StorageLevel}
-object NettyMessages {
- /** Request to read a set of blocks. Returns [[ShuffleStreamHandle]] to identify the stream. */
- case class OpenBlocks(blockIds: Seq[BlockId])
-
- /** Request to upload a block with a certain StorageLevel. Returns nothing (empty byte array). */
- case class UploadBlock(blockId: BlockId, blockData: Array[Byte], level: StorageLevel)
-}
-
/**
* Serves requests to open blocks by simply registering one chunk per block requested.
* Handles opening and uploading arbitrary BlockManager blocks.
@@ -50,28 +42,29 @@ class NettyBlockRpcServer(
blockManager: BlockDataManager)
extends RpcHandler with Logging {
- import NettyMessages._
-
private val streamManager = new OneForOneStreamManager()
override def receive(
client: TransportClient,
messageBytes: Array[Byte],
responseContext: RpcResponseCallback): Unit = {
- val ser = serializer.newInstance()
- val message = ser.deserialize[AnyRef](ByteBuffer.wrap(messageBytes))
+ val message = BlockTransferMessage.Decoder.fromByteArray(messageBytes)
logTrace(s"Received request: $message")
message match {
- case OpenBlocks(blockIds) =>
- val blocks: Seq[ManagedBuffer] = blockIds.map(blockManager.getBlockData)
+ case openBlocks: OpenBlocks =>
+ val blocks: Seq[ManagedBuffer] =
+ openBlocks.blockIds.map(BlockId.apply).map(blockManager.getBlockData)
val streamId = streamManager.registerStream(blocks.iterator)
logTrace(s"Registered streamId $streamId with ${blocks.size} buffers")
- responseContext.onSuccess(
- ser.serialize(new ShuffleStreamHandle(streamId, blocks.size)).array())
+ responseContext.onSuccess(new StreamHandle(streamId, blocks.size).toByteArray)
- case UploadBlock(blockId, blockData, level) =>
- blockManager.putBlockData(blockId, new NioManagedBuffer(ByteBuffer.wrap(blockData)), level)
+ case uploadBlock: UploadBlock =>
+ // StorageLevel is serialized as bytes using our JavaSerializer.
+ val level: StorageLevel =
+ serializer.newInstance().deserialize(ByteBuffer.wrap(uploadBlock.metadata))
+ val data = new NioManagedBuffer(ByteBuffer.wrap(uploadBlock.blockData))
+ blockManager.putBlockData(BlockId(uploadBlock.blockId), data, level)
responseContext.onSuccess(new Array[Byte](0))
}
}
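The RPC server now decodes a binary BlockTransferMessage and pattern-matches on its subtype instead of Java-deserializing ad-hoc case classes. A simplified sketch of that decode-then-dispatch shape, using illustrative message types rather than the real org.apache.spark.network.shuffle.protocol classes:

    object RpcDispatchSketch {
      // Illustrative message hierarchy; the real one lives in the shuffle protocol package.
      sealed trait Message
      case class Open(blockIds: Seq[String]) extends Message
      case class Upload(blockId: String, data: Array[Byte]) extends Message

      // Stand-in for BlockTransferMessage.Decoder.fromByteArray: here the first
      // byte is treated as a type tag purely for illustration.
      def decode(bytes: Array[Byte]): Message =
        if (bytes.nonEmpty && bytes(0) == 0) Open(Seq("rdd_0_0"))
        else Upload("rdd_0_0", bytes.drop(1))

      def receive(bytes: Array[Byte]): String = decode(bytes) match {
        case Open(ids)        => s"registered stream for ${ids.size} blocks"
        case Upload(id, data) => s"stored block $id (${data.length} bytes)"
      }

      def main(args: Array[String]): Unit = {
        println(receive(Array[Byte](0)))
        println(receive(Array[Byte](1, 9, 9)))
      }
    }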
diff --git a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
index 1c4327cf13b51..3f0950dae1f24 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/NettyBlockTransferService.scala
@@ -17,15 +17,17 @@
package org.apache.spark.network.netty
+import scala.collection.JavaConversions._
import scala.concurrent.{Future, Promise}
-import org.apache.spark.SparkConf
+import org.apache.spark.{SecurityManager, SparkConf}
import org.apache.spark.network._
import org.apache.spark.network.buffer.ManagedBuffer
-import org.apache.spark.network.client.{RpcResponseCallback, TransportClientFactory}
-import org.apache.spark.network.netty.NettyMessages.{OpenBlocks, UploadBlock}
+import org.apache.spark.network.client.{TransportClientBootstrap, RpcResponseCallback, TransportClientFactory}
+import org.apache.spark.network.sasl.{SaslRpcHandler, SaslClientBootstrap}
import org.apache.spark.network.server._
-import org.apache.spark.network.shuffle.{BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.{RetryingBlockFetcher, BlockFetchingListener, OneForOneBlockFetcher}
+import org.apache.spark.network.shuffle.protocol.UploadBlock
import org.apache.spark.serializer.JavaSerializer
import org.apache.spark.storage.{BlockId, StorageLevel}
import org.apache.spark.util.Utils
@@ -33,19 +35,33 @@ import org.apache.spark.util.Utils
/**
 * A BlockTransferService that uses Netty to fetch a set of blocks at a time.
*/
-class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
+class NettyBlockTransferService(conf: SparkConf, securityManager: SecurityManager, numCores: Int)
+ extends BlockTransferService {
+
// TODO: Don't use Java serialization, use a more cross-version compatible serialization format.
- val serializer = new JavaSerializer(conf)
+ private val serializer = new JavaSerializer(conf)
+ private val authEnabled = securityManager.isAuthenticationEnabled()
+ private val transportConf = SparkTransportConf.fromSparkConf(conf, numCores)
private[this] var transportContext: TransportContext = _
private[this] var server: TransportServer = _
private[this] var clientFactory: TransportClientFactory = _
+ private[this] var appId: String = _
override def init(blockDataManager: BlockDataManager): Unit = {
- val rpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
- transportContext = new TransportContext(SparkTransportConf.fromSparkConf(conf), rpcHandler)
- clientFactory = transportContext.createClientFactory()
- server = transportContext.createServer()
+ val (rpcHandler: RpcHandler, bootstrap: Option[TransportClientBootstrap]) = {
+ val nettyRpcHandler = new NettyBlockRpcServer(serializer, blockDataManager)
+ if (!authEnabled) {
+ (nettyRpcHandler, None)
+ } else {
+ (new SaslRpcHandler(nettyRpcHandler, securityManager),
+ Some(new SaslClientBootstrap(transportConf, conf.getAppId, securityManager)))
+ }
+ }
+ transportContext = new TransportContext(transportConf, rpcHandler)
+ clientFactory = transportContext.createClientFactory(bootstrap.toList)
+ server = transportContext.createServer(conf.getInt("spark.blockManager.port", 0))
+ appId = conf.getAppId
logInfo("Server created on " + server.getPort)
}
@@ -57,9 +73,21 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
listener: BlockFetchingListener): Unit = {
logTrace(s"Fetch blocks from $host:$port (executor id $execId)")
try {
- val client = clientFactory.createClient(host, port)
- new OneForOneBlockFetcher(client, blockIds.toArray, listener)
- .start(OpenBlocks(blockIds.map(BlockId.apply)))
+ val blockFetchStarter = new RetryingBlockFetcher.BlockFetchStarter {
+ override def createAndStart(blockIds: Array[String], listener: BlockFetchingListener) {
+ val client = clientFactory.createClient(host, port)
+ new OneForOneBlockFetcher(client, appId, execId, blockIds.toArray, listener).start()
+ }
+ }
+
+ val maxRetries = transportConf.maxIORetries()
+ if (maxRetries > 0) {
+ // Note this Fetcher will correctly handle maxRetries == 0; we avoid it just in case there's
+ // a bug in this code. We should remove the if statement once we're sure of the stability.
+ new RetryingBlockFetcher(transportConf, blockFetchStarter, blockIds, listener).start()
+ } else {
+ blockFetchStarter.createAndStart(blockIds, listener)
+ }
} catch {
case e: Exception =>
logError("Exception while beginning fetchBlocks", e)
@@ -74,12 +102,17 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
override def uploadBlock(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel): Future[Unit] = {
val result = Promise[Unit]()
val client = clientFactory.createClient(hostname, port)
+ // StorageLevel is serialized as bytes using our JavaSerializer. Everything else is encoded
+ // using our binary protocol.
+ val levelBytes = serializer.newInstance().serialize(level).array()
+
// Convert or copy nio buffer into array in order to serialize it.
val nioBuffer = blockData.nioByteBuffer()
val array = if (nioBuffer.hasArray) {
@@ -90,8 +123,7 @@ class NettyBlockTransferService(conf: SparkConf) extends BlockTransferService {
data
}
- val ser = serializer.newInstance()
- client.sendRpc(ser.serialize(new UploadBlock(blockId, array, level)).array(),
+ client.sendRpc(new UploadBlock(appId, execId, blockId.toString, levelBytes, array).toByteArray,
new RpcResponseCallback {
override def onSuccess(response: Array[Byte]): Unit = {
logTrace(s"Successfully uploaded block $blockId")
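Fetches now go through a BlockFetchStarter that is either invoked once directly or driven by RetryingBlockFetcher when spark.shuffle.io.maxRetries is positive. A rough, self-contained sketch of that retry-or-direct decision; the retry loop below only approximates what the real fetcher does:

    object RetryOrDirectSketch {
      // Stand-in for RetryingBlockFetcher.BlockFetchStarter.
      trait FetchStarter {
        def createAndStart(blockIds: Array[String]): Unit
      }

      def fetchBlocks(blockIds: Array[String], maxRetries: Int, starter: FetchStarter): Unit = {
        if (maxRetries > 0) {
          // Simplified retry loop standing in for RetryingBlockFetcher:
          // re-invoke the starter until it succeeds or retries are exhausted.
          var attempt = 0
          var done = false
          while (!done && attempt <= maxRetries) {
            try { starter.createAndStart(blockIds); done = true }
            catch { case e: Exception => attempt += 1; if (attempt > maxRetries) throw e }
          }
        } else {
          starter.createAndStart(blockIds) // single attempt, as in the else branch above
        }
      }

      def main(args: Array[String]): Unit = {
        var calls = 0
        fetchBlocks(Array("shuffle_0_0_0"), maxRetries = 3, new FetchStarter {
          def createAndStart(ids: Array[String]): Unit = {
            calls += 1
            if (calls < 2) throw new RuntimeException("transient IO error")
            println(s"fetched ${ids.length} block(s) after $calls attempt(s)")
          }
        })
      }
    }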
diff --git a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
index 9fa4fa77b8817..cef203006d685 100644
--- a/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
+++ b/core/src/main/scala/org/apache/spark/network/netty/SparkTransportConf.scala
@@ -21,12 +21,53 @@ import org.apache.spark.SparkConf
import org.apache.spark.network.util.{TransportConf, ConfigProvider}
/**
- * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ * Provides a utility for transforming from a SparkConf inside a Spark JVM (e.g., Executor,
+ * Driver, or a standalone shuffle service) into a TransportConf with details on our environment
+ * like the number of cores that are allocated to this JVM.
*/
object SparkTransportConf {
- def fromSparkConf(conf: SparkConf): TransportConf = {
+ /**
+ * Specifies an upper bound on the number of Netty threads that Spark requires by default.
+ * In practice, only 2-4 cores should be required to transfer roughly 10 Gb/s, and each core
+ * that we use will have an initial overhead of roughly 32 MB of off-heap memory, which comes
+ * at a premium.
+ *
+ * Thus, this value should still retain maximum throughput and reduce wasted off-heap memory
+ * allocation. It can be overridden by setting the number of serverThreads and clientThreads
+ * manually in Spark's configuration.
+ */
+ private val MAX_DEFAULT_NETTY_THREADS = 8
+
+ /**
+ * Utility for creating a [[TransportConf]] from a [[SparkConf]].
+ * @param numUsableCores if nonzero, this will restrict the server and client threads to only
+ * use the given number of cores, rather than all of the machine's cores.
+ * This restriction will only occur if these properties are not already set.
+ */
+ def fromSparkConf(_conf: SparkConf, numUsableCores: Int = 0): TransportConf = {
+ val conf = _conf.clone
+
+ // Specify thread configuration based on our JVM's allocation of cores (rather than necessarily
+ // assuming we have all the machine's cores).
+ // NB: Only set if serverThreads/clientThreads not already set.
+ val numThreads = defaultNumThreads(numUsableCores)
+ conf.set("spark.shuffle.io.serverThreads",
+ conf.get("spark.shuffle.io.serverThreads", numThreads.toString))
+ conf.set("spark.shuffle.io.clientThreads",
+ conf.get("spark.shuffle.io.clientThreads", numThreads.toString))
+
new TransportConf(new ConfigProvider {
override def get(name: String): String = conf.get(name)
})
}
+
+ /**
+ * Returns the default number of threads for both the Netty client and server thread pools.
+   * If numUsableCores is 0, we will use the Runtime to get an approximate number of available cores.
+ */
+ private def defaultNumThreads(numUsableCores: Int): Int = {
+ val availableCores =
+ if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
+ math.min(availableCores, MAX_DEFAULT_NETTY_THREADS)
+ }
}
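The default Netty thread count is simply min(availableCores, 8), where availableCores falls back to Runtime.availableProcessors() when numUsableCores is 0. A tiny sketch of the same arithmetic:

    object NettyThreadDefaultSketch {
      private val MaxDefaultNettyThreads = 8 // mirrors MAX_DEFAULT_NETTY_THREADS above

      def defaultNumThreads(numUsableCores: Int): Int = {
        val availableCores =
          if (numUsableCores > 0) numUsableCores else Runtime.getRuntime.availableProcessors()
        math.min(availableCores, MaxDefaultNettyThreads)
      }

      def main(args: Array[String]): Unit = {
        println(defaultNumThreads(4))  // 4: below the cap
        println(defaultNumThreads(32)) // 8: capped
        println(defaultNumThreads(0))  // whatever the JVM reports, capped at 8
      }
    }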
diff --git a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
index 4f6f5e235811d..c2d9578be7ebb 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/Connection.scala
@@ -23,12 +23,13 @@ import java.nio.channels._
import java.util.concurrent.ConcurrentLinkedQueue
import java.util.LinkedList
-import org.apache.spark._
-
import scala.collection.JavaConversions._
import scala.collection.mutable.{ArrayBuffer, HashMap}
import scala.util.control.NonFatal
+import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
+
private[nio]
abstract class Connection(val channel: SocketChannel, val selector: Selector,
val socketRemoteConnectionManagerId: ConnectionManagerId, val connectionId: ConnectionId,
diff --git a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
index 8408b75bb4d65..302b496b8a849 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/ConnectionManager.scala
@@ -18,13 +18,13 @@
package org.apache.spark.network.nio
import java.io.IOException
+import java.lang.ref.WeakReference
import java.net._
import java.nio._
import java.nio.channels._
import java.nio.channels.spi._
import java.util.concurrent.atomic.AtomicInteger
import java.util.concurrent.{LinkedBlockingDeque, ThreadPoolExecutor, TimeUnit}
-import java.util.{Timer, TimerTask}
import scala.collection.mutable.{ArrayBuffer, HashMap, HashSet, SynchronizedMap, SynchronizedQueue}
import scala.concurrent.duration._
@@ -32,8 +32,10 @@ import scala.concurrent.{Await, ExecutionContext, Future, Promise}
import scala.language.postfixOps
import com.google.common.base.Charsets.UTF_8
+import io.netty.util.{Timeout, TimerTask, HashedWheelTimer}
import org.apache.spark._
+import org.apache.spark.network.sasl.{SparkSaslClient, SparkSaslServer}
import org.apache.spark.util.Utils
import scala.util.Try
@@ -76,7 +78,8 @@ private[nio] class ConnectionManager(
}
private val selector = SelectorProvider.provider.openSelector()
- private val ackTimeoutMonitor = new Timer("AckTimeoutMonitor", true)
+ private val ackTimeoutMonitor =
+ new HashedWheelTimer(Utils.namedThreadFactory("AckTimeoutMonitor"))
private val ackTimeout = conf.getInt("spark.core.connection.ack.wait.timeout", 60)
@@ -138,7 +141,10 @@ private[nio] class ConnectionManager(
new HashMap[SelectionKey, Connection] with SynchronizedMap[SelectionKey, Connection]
private val connectionsById = new HashMap[ConnectionManagerId, SendingConnection]
with SynchronizedMap[ConnectionManagerId, SendingConnection]
- private val messageStatuses = new HashMap[Int, MessageStatus]
+ // Tracks sent messages for which we are awaiting acknowledgements. Entries are added to this
+ // map when messages are sent and are removed when acknowledgement messages are received or when
+  // acknowledgement timeouts expire.
+ private val messageStatuses = new HashMap[Int, MessageStatus] // [MessageId, MessageStatus]
private val keyInterestChangeRequests = new SynchronizedQueue[(SelectionKey, Int)]
private val registerRequests = new SynchronizedQueue[SendingConnection]
@@ -158,7 +164,7 @@ private[nio] class ConnectionManager(
serverChannel.socket.bind(new InetSocketAddress(port))
(serverChannel, serverChannel.socket.getLocalPort)
}
- Utils.startServiceOnPort[ServerSocketChannel](port, startService, name)
+ Utils.startServiceOnPort[ServerSocketChannel](port, startService, conf, name)
serverChannel.register(selector, SelectionKey.OP_ACCEPT)
val id = new ConnectionManagerId(Utils.localHostName, serverChannel.socket.getLocalPort)
@@ -600,7 +606,7 @@ private[nio] class ConnectionManager(
} else {
var replyToken : Array[Byte] = null
try {
- replyToken = waitingConn.sparkSaslClient.saslResponse(securityMsg.getToken)
+ replyToken = waitingConn.sparkSaslClient.response(securityMsg.getToken)
if (waitingConn.isSaslComplete()) {
logDebug("Client sasl completed after evaluate for id: " + waitingConn.connectionId)
connectionsAwaitingSasl -= waitingConn.connectionId
@@ -634,7 +640,7 @@ private[nio] class ConnectionManager(
connection.synchronized {
if (connection.sparkSaslServer == null) {
logDebug("Creating sasl Server")
- connection.sparkSaslServer = new SparkSaslServer(securityManager)
+ connection.sparkSaslServer = new SparkSaslServer(conf.getAppId, securityManager)
}
}
replyToken = connection.sparkSaslServer.response(securityMsg.getToken)
@@ -778,7 +784,7 @@ private[nio] class ConnectionManager(
if (!conn.isSaslComplete()) {
conn.synchronized {
if (conn.sparkSaslClient == null) {
- conn.sparkSaslClient = new SparkSaslClient(securityManager)
+ conn.sparkSaslClient = new SparkSaslClient(conf.getAppId, securityManager)
var firstResponse: Array[Byte] = null
try {
firstResponse = conn.sparkSaslClient.firstToken()
@@ -898,22 +904,41 @@ private[nio] class ConnectionManager(
: Future[Message] = {
val promise = Promise[Message]()
- val timeoutTask = new TimerTask {
- override def run(): Unit = {
+ // It's important that the TimerTask doesn't capture a reference to `message`, which can cause
+ // memory leaks since cancelled TimerTasks won't necessarily be garbage collected until the time
+ // at which they would originally be scheduled to run. Therefore, extract the message id
+ // from outside of the TimerTask closure (see SPARK-4393 for more context).
+ val messageId = message.id
+ // Keep a weak reference to the promise so that the completed promise may be garbage-collected
+ val promiseReference = new WeakReference(promise)
+ val timeoutTask: TimerTask = new TimerTask {
+ override def run(timeout: Timeout): Unit = {
messageStatuses.synchronized {
- messageStatuses.remove(message.id).foreach ( s => {
+ messageStatuses.remove(messageId).foreach { s =>
val e = new IOException("sendMessageReliably failed because ack " +
s"was not received within $ackTimeout sec")
- if (!promise.tryFailure(e)) {
- logWarning("Ignore error because promise is completed", e)
+ val p = promiseReference.get
+ if (p != null) {
+ // Attempt to fail the promise with a Timeout exception
+ if (!p.tryFailure(e)) {
+ // If we reach here, then someone else has already signalled success or failure
+ // on this promise, so log a warning:
+ logError("Ignore error because promise is completed", e)
+ }
+ } else {
+ // The WeakReference was empty, which should never happen because
+ // sendMessageReliably's caller should have a strong reference to promise.future;
+ logError("Promise was garbage collected; this should never happen!", e)
}
- })
+ }
}
}
}
+ val timeoutTaskHandle = ackTimeoutMonitor.newTimeout(timeoutTask, ackTimeout, TimeUnit.SECONDS)
+
val status = new MessageStatus(message, connectionManagerId, s => {
- timeoutTask.cancel()
+ timeoutTaskHandle.cancel()
s match {
case scala.util.Failure(e) =>
// Indicates a failure where we either never sent or never got ACK'd
@@ -942,7 +967,6 @@ private[nio] class ConnectionManager(
messageStatuses += ((message.id, status))
}
- ackTimeoutMonitor.schedule(timeoutTask, ackTimeout * 1000)
sendMessage(connectionManagerId, message)
promise.future
}
@@ -952,7 +976,7 @@ private[nio] class ConnectionManager(
}
def stop() {
- ackTimeoutMonitor.cancel()
+ ackTimeoutMonitor.stop()
selectorThread.interrupt()
selectorThread.join()
selector.close()
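The ack-timeout monitor switches from java.util.Timer to Netty's HashedWheelTimer, and the timeout task now captures only the message id plus a WeakReference to the promise (SPARK-4393). A standalone sketch of that scheduling shape, assuming Netty is on the classpath as it is for Spark core:

    import java.lang.ref.WeakReference
    import java.util.concurrent.TimeUnit
    import io.netty.util.{HashedWheelTimer, Timeout, TimerTask}
    import scala.concurrent.Promise

    object AckTimeoutSketch {
      private val timer = new HashedWheelTimer()

      // Schedule a timeout that fails the promise unless it completes first.
      // Only the message id and a weak reference to the promise are captured.
      def scheduleAckTimeout(messageId: Int, promise: Promise[String], seconds: Long): Timeout = {
        val promiseRef = new WeakReference(promise)
        val task = new TimerTask {
          override def run(timeout: Timeout): Unit = {
            val p = promiseRef.get
            if (p != null) {
              p.tryFailure(new java.io.IOException(s"no ack for message $messageId"))
            }
          }
        }
        timer.newTimeout(task, seconds, TimeUnit.SECONDS)
      }

      def main(args: Array[String]): Unit = {
        val promise = Promise[String]()
        val handle = scheduleAckTimeout(42, promise, seconds = 1)
        promise.trySuccess("ack received") // ack arrived in time...
        handle.cancel()                    // ...so cancel the pending timeout
        timer.stop()
        println(promise.future.value)
      }
    }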
diff --git a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
index f56d165daba55..b2aec160635c7 100644
--- a/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
+++ b/core/src/main/scala/org/apache/spark/network/nio/NioBlockTransferService.scala
@@ -137,6 +137,7 @@ final class NioBlockTransferService(conf: SparkConf, securityManager: SecurityMa
override def uploadBlock(
hostname: String,
port: Int,
+ execId: String,
blockId: BlockId,
blockData: ManagedBuffer,
level: StorageLevel)
diff --git a/core/src/main/scala/org/apache/spark/package.scala b/core/src/main/scala/org/apache/spark/package.scala
index e2fc9c649925e..9bb5f5ea37ee8 100644
--- a/core/src/main/scala/org/apache/spark/package.scala
+++ b/core/src/main/scala/org/apache/spark/package.scala
@@ -44,5 +44,5 @@ package org.apache
package object spark {
// For package docs only
- val SPARK_VERSION = "1.2.0-SNAPSHOT"
+ val SPARK_VERSION = "1.2.1"
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
index ffc0a8a6d67eb..70edf191d928a 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoGroupedRDD.scala
@@ -60,7 +60,7 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
* A RDD that cogroups its parents. For each key k in parent RDDs, the resulting RDD contains a
* tuple with the list of values for that key.
*
- * Note: This is an internal API. We recommend users use RDD.coGroup(...) instead of
+ * Note: This is an internal API. We recommend users use RDD.cogroup(...) instead of
* instantiating this directly.
* @param rdds parent RDDs.
@@ -70,8 +70,8 @@ private[spark] class CoGroupPartition(idx: Int, val deps: Array[CoGroupSplitDep]
class CoGroupedRDD[K](@transient var rdds: Seq[RDD[_ <: Product2[K, _]]], part: Partitioner)
extends RDD[(K, Array[Iterable[_]])](rdds.head.context, Nil) {
- // For example, `(k, a) cogroup (k, b)` produces k -> Seq(ArrayBuffer as, ArrayBuffer bs).
- // Each ArrayBuffer is represented as a CoGroup, and the resulting Seq as a CoGroupCombiner.
+ // For example, `(k, a) cogroup (k, b)` produces k -> Array(ArrayBuffer as, ArrayBuffer bs).
+ // Each ArrayBuffer is represented as a CoGroup, and the resulting Array as a CoGroupCombiner.
// CoGroupValue is the intermediate state of each value before being merged in compute.
private type CoGroup = CompactBuffer[Any]
private type CoGroupValue = (Any, Int) // Int is dependency number
diff --git a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
index 9fab1d78abb04..b073eba8a1574 100644
--- a/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/CoalescedRDD.scala
@@ -35,11 +35,10 @@ import org.apache.spark.util.Utils
* @param preferredLocation the preferred location for this partition
*/
private[spark] case class CoalescedRDDPartition(
- index: Int,
- @transient rdd: RDD[_],
- parentsIndices: Array[Int],
- @transient preferredLocation: String = ""
- ) extends Partition {
+ index: Int,
+ @transient rdd: RDD[_],
+ parentsIndices: Array[Int],
+ @transient preferredLocation: Option[String] = None) extends Partition {
var parents: Seq[Partition] = parentsIndices.map(rdd.partitions(_))
@throws(classOf[IOException])
@@ -55,9 +54,10 @@ private[spark] case class CoalescedRDDPartition(
* @return locality of this coalesced partition between 0 and 1
*/
def localFraction: Double = {
- val loc = parents.count(p =>
- rdd.context.getPreferredLocs(rdd, p.index).map(tl => tl.host).contains(preferredLocation))
-
+ val loc = parents.count { p =>
+ val parentPreferredLocations = rdd.context.getPreferredLocs(rdd, p.index).map(_.host)
+ preferredLocation.exists(parentPreferredLocations.contains)
+ }
if (parents.size == 0) 0.0 else (loc.toDouble / parents.size.toDouble)
}
}
@@ -73,9 +73,9 @@ private[spark] case class CoalescedRDDPartition(
* @param balanceSlack used to trade-off balance and locality. 1.0 is all locality, 0 is all balance
*/
private[spark] class CoalescedRDD[T: ClassTag](
- @transient var prev: RDD[T],
- maxPartitions: Int,
- balanceSlack: Double = 0.10)
+ @transient var prev: RDD[T],
+ maxPartitions: Int,
+ balanceSlack: Double = 0.10)
extends RDD[T](prev.context, Nil) { // Nil since we implement getDependencies
override def getPartitions: Array[Partition] = {
@@ -113,7 +113,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
* @return the machine most preferred by split
*/
override def getPreferredLocations(partition: Partition): Seq[String] = {
- List(partition.asInstanceOf[CoalescedRDDPartition].preferredLocation)
+ partition.asInstanceOf[CoalescedRDDPartition].preferredLocation.toSeq
}
}
@@ -147,7 +147,7 @@ private[spark] class CoalescedRDD[T: ClassTag](
*
*/
-private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
+private class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanceSlack: Double) {
def compare(o1: PartitionGroup, o2: PartitionGroup): Boolean = o1.size < o2.size
def compare(o1: Option[PartitionGroup], o2: Option[PartitionGroup]): Boolean =
@@ -341,8 +341,14 @@ private[spark] class PartitionCoalescer(maxPartitions: Int, prev: RDD[_], balanc
}
}
-private[spark] case class PartitionGroup(prefLoc: String = "") {
+private case class PartitionGroup(prefLoc: Option[String] = None) {
var arr = mutable.ArrayBuffer[Partition]()
-
def size = arr.size
}
+
+private object PartitionGroup {
+ def apply(prefLoc: String): PartitionGroup = {
+ require(prefLoc != "", "Preferred location must not be empty")
+ PartitionGroup(Some(prefLoc))
+ }
+}
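With preferredLocation now an Option, localFraction counts how many parent partitions list that host among their preferred locations. A small worked restatement of the calculation, with made-up host names:

    object LocalFractionSketch {
      // Hedged restatement of CoalescedRDDPartition.localFraction: the fraction of
      // parent partitions whose preferred hosts include this coalesced partition's
      // (optional) preferred host.
      def localFraction(
          parentPreferredHosts: Seq[Seq[String]],
          preferredLocation: Option[String]): Double = {
        val local = parentPreferredHosts.count(hosts => preferredLocation.exists(hosts.contains))
        if (parentPreferredHosts.isEmpty) 0.0 else local.toDouble / parentPreferredHosts.size
      }

      def main(args: Array[String]): Unit = {
        val parents = Seq(Seq("hostA", "hostB"), Seq("hostC"), Seq("hostA"))
        println(localFraction(parents, Some("hostA"))) // 2 of 3 parents are local -> ~0.67
        println(localFraction(parents, None))          // no preferred location -> 0.0
      }
    }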
diff --git a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
index 946fb5616d3ec..a157e36e2286e 100644
--- a/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/HadoopRDD.scala
@@ -211,20 +211,11 @@ class HadoopRDD[K, V](
val split = theSplit.asInstanceOf[HadoopPartition]
logInfo("Input split: " + split.inputSplit)
- var reader: RecordReader[K, V] = null
val jobConf = getJobConf()
- val inputFormat = getInputFormat(jobConf)
- HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
- context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
- reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
-
- // Register an on-task-completion callback to close the input stream.
- context.addTaskCompletionListener{ context => closeIfNeeded() }
- val key: K = reader.createKey()
- val value: V = reader.createValue()
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- // Find a function that will return the FileSystem bytes read by this thread.
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = if (split.inputSplit.value.isInstanceOf[FileSplit]) {
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
split.inputSplit.value.asInstanceOf[FileSplit].getPath, jobConf)
@@ -234,6 +225,18 @@ class HadoopRDD[K, V](
if (bytesReadCallback.isDefined) {
context.taskMetrics.inputMetrics = Some(inputMetrics)
}
+
+ var reader: RecordReader[K, V] = null
+ val inputFormat = getInputFormat(jobConf)
+ HadoopRDD.addLocalConfiguration(new SimpleDateFormat("yyyyMMddHHmm").format(createTime),
+ context.stageId, theSplit.index, context.attemptId.toInt, jobConf)
+ reader = inputFormat.getRecordReader(split.inputSplit.value, jobConf, Reporter.NULL)
+
+ // Register an on-task-completion callback to close the input stream.
+ context.addTaskCompletionListener{ context => closeIfNeeded() }
+ val key: K = reader.createKey()
+ val value: V = reader.createValue()
+
var recordsSinceMetricsUpdate = 0
override def getNext() = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
index 0e38f224ac81d..642a12c1edf6c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/JdbcRDD.scala
@@ -21,8 +21,11 @@ import java.sql.{Connection, ResultSet}
import scala.reflect.ClassTag
-import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
+import org.apache.spark.api.java.JavaSparkContext.fakeClassTag
+import org.apache.spark.api.java.function.{Function => JFunction}
+import org.apache.spark.api.java.{JavaRDD, JavaSparkContext}
import org.apache.spark.util.NextIterator
+import org.apache.spark.{Logging, Partition, SparkContext, TaskContext}
private[spark] class JdbcPartition(idx: Int, val lower: Long, val upper: Long) extends Partition {
override def index = idx
@@ -125,5 +128,82 @@ object JdbcRDD {
def resultSetToObjectArray(rs: ResultSet): Array[Object] = {
Array.tabulate[Object](rs.getMetaData.getColumnCount)(i => rs.getObject(i + 1))
}
-}
+ trait ConnectionFactory extends Serializable {
+ @throws[Exception]
+ def getConnection: Connection
+ }
+
+ /**
+ * Create an RDD that executes an SQL query on a JDBC connection and reads results.
+ * For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+ *
+ * @param connectionFactory a factory that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ * @param mapRow a function from a ResultSet to a single row of the desired result type(s).
+ * This should only call getInt, getString, etc; the RDD takes care of calling next.
+ * The default maps a ResultSet to an array of Object.
+ */
+ def create[T](
+ sc: JavaSparkContext,
+ connectionFactory: ConnectionFactory,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int,
+ mapRow: JFunction[ResultSet, T]): JavaRDD[T] = {
+
+ val jdbcRDD = new JdbcRDD[T](
+ sc.sc,
+ () => connectionFactory.getConnection,
+ sql,
+ lowerBound,
+ upperBound,
+ numPartitions,
+ (resultSet: ResultSet) => mapRow.call(resultSet))(fakeClassTag)
+
+ new JavaRDD[T](jdbcRDD)(fakeClassTag)
+ }
+
+ /**
+ * Create an RDD that executes an SQL query on a JDBC connection and reads results. Each row is
+ * converted into a `Object` array. For usage example, see test case JavaAPISuite.testJavaJdbcRDD.
+ *
+ * @param connectionFactory a factory that returns an open Connection.
+ * The RDD takes care of closing the connection.
+ * @param sql the text of the query.
+ * The query must contain two ? placeholders for parameters used to partition the results.
+ * E.g. "select title, author from books where ? <= id and id <= ?"
+ * @param lowerBound the minimum value of the first placeholder
+ * @param upperBound the maximum value of the second placeholder
+ * The lower and upper bounds are inclusive.
+ * @param numPartitions the number of partitions.
+ * Given a lowerBound of 1, an upperBound of 20, and a numPartitions of 2,
+ * the query would be executed twice, once with (1, 10) and once with (11, 20)
+ */
+ def create(
+ sc: JavaSparkContext,
+ connectionFactory: ConnectionFactory,
+ sql: String,
+ lowerBound: Long,
+ upperBound: Long,
+ numPartitions: Int): JavaRDD[Array[Object]] = {
+
+ val mapRow = new JFunction[ResultSet, Array[Object]] {
+ override def call(resultSet: ResultSet): Array[Object] = {
+ resultSetToObjectArray(resultSet)
+ }
+ }
+
+ create(sc, connectionFactory, sql, lowerBound, upperBound, numPartitions, mapRow)
+ }
+}
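The new create overloads are Java-friendly wrappers over the existing constructor; the interesting part is how (lowerBound, upperBound, numPartitions) becomes per-partition inclusive ranges. A hedged, standalone sketch of that split, written to reproduce the (1, 10) / (11, 20) example in the scaladoc above:

    object JdbcBoundsSketch {
      // Approximate restatement of how the bounds are divided among partitions;
      // each pair is used as the two ? placeholders of the query.
      def partitionBounds(lowerBound: Long, upperBound: Long, numPartitions: Int): Seq[(Long, Long)] = {
        val length = 1 + upperBound - lowerBound
        (0 until numPartitions).map { i =>
          val start = lowerBound + (i * length) / numPartitions
          val end = lowerBound + ((i + 1) * length) / numPartitions - 1
          (start, end)
        }
      }

      def main(args: Array[String]): Unit = {
        // lowerBound = 1, upperBound = 20, numPartitions = 2:
        // the query runs once with (1, 10) and once with (11, 20).
        println(partitionBounds(1, 20, 2))
      }
    }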
diff --git a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
index 6d6b86721ca74..e55d03d391e03 100644
--- a/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/NewHadoopRDD.scala
@@ -35,6 +35,7 @@ import org.apache.spark.Partition
import org.apache.spark.SerializableWritable
import org.apache.spark.{SparkContext, TaskContext}
import org.apache.spark.executor.{DataReadMethod, InputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.rdd.NewHadoopRDD.NewHadoopMapPartitionsWithSplitRDD
import org.apache.spark.util.Utils
import org.apache.spark.deploy.SparkHadoopUtil
@@ -107,20 +108,10 @@ class NewHadoopRDD[K, V](
val split = theSplit.asInstanceOf[NewHadoopPartition]
logInfo("Input split: " + split.serializableHadoopSplit)
val conf = confBroadcast.value.value
- val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
- val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
- val format = inputFormatClass.newInstance
- format match {
- case configurable: Configurable =>
- configurable.setConf(conf)
- case _ =>
- }
- val reader = format.createRecordReader(
- split.serializableHadoopSplit.value, hadoopAttemptContext)
- reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
val inputMetrics = new InputMetrics(DataReadMethod.Hadoop)
- // Find a function that will return the FileSystem bytes read by this thread.
+ // Find a function that will return the FileSystem bytes read by this thread. Do this before
+ // creating RecordReader, because RecordReader's constructor might read some bytes
val bytesReadCallback = if (split.serializableHadoopSplit.value.isInstanceOf[FileSplit]) {
SparkHadoopUtil.get.getFSBytesReadOnThreadCallback(
split.serializableHadoopSplit.value.asInstanceOf[FileSplit].getPath, conf)
@@ -131,6 +122,18 @@ class NewHadoopRDD[K, V](
context.taskMetrics.inputMetrics = Some(inputMetrics)
}
+ val attemptId = newTaskAttemptID(jobTrackerId, id, isMap = true, split.index, 0)
+ val hadoopAttemptContext = newTaskAttemptContext(conf, attemptId)
+ val format = inputFormatClass.newInstance
+ format match {
+ case configurable: Configurable =>
+ configurable.setConf(conf)
+ case _ =>
+ }
+ val reader = format.createRecordReader(
+ split.serializableHadoopSplit.value, hadoopAttemptContext)
+ reader.initialize(split.serializableHadoopSplit.value, hadoopAttemptContext)
+
// Register an on-task-completion callback to close the input stream.
context.addTaskCompletionListener(context => close())
var havePair = false
diff --git a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
index da89f634abaea..2c8bb657b521c 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PairRDDFunctions.scala
@@ -25,21 +25,24 @@ import scala.collection.{Map, mutable}
import scala.collection.JavaConversions._
import scala.collection.mutable.ArrayBuffer
import scala.reflect.ClassTag
+import scala.util.DynamicVariable
import com.clearspring.analytics.stream.cardinality.HyperLogLogPlus
import org.apache.hadoop.conf.{Configurable, Configuration}
-import org.apache.hadoop.fs.FileSystem
+import org.apache.hadoop.fs.{FileSystem, Path}
import org.apache.hadoop.io.SequenceFile.CompressionType
import org.apache.hadoop.io.compress.CompressionCodec
import org.apache.hadoop.mapred.{FileOutputCommitter, FileOutputFormat, JobConf, OutputFormat}
import org.apache.hadoop.mapreduce.{Job => NewAPIHadoopJob, OutputFormat => NewOutputFormat,
-RecordWriter => NewRecordWriter, SparkHadoopMapReduceUtil}
+RecordWriter => NewRecordWriter}
import org.apache.spark._
import org.apache.spark.Partitioner.defaultPartitioner
import org.apache.spark.SparkContext._
import org.apache.spark.annotation.Experimental
import org.apache.spark.deploy.SparkHadoopUtil
+import org.apache.spark.executor.{DataWriteMethod, OutputMetrics}
+import org.apache.spark.mapreduce.SparkHadoopMapReduceUtil
import org.apache.spark.partial.{BoundedDouble, PartialResult}
import org.apache.spark.serializer.Serializer
import org.apache.spark.util.Utils
@@ -84,7 +87,10 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
throw new SparkException("Default partitioner cannot partition array keys.")
}
}
- val aggregator = new Aggregator[K, V, C](createCombiner, mergeValue, mergeCombiners)
+ val aggregator = new Aggregator[K, V, C](
+ self.context.clean(createCombiner),
+ self.context.clean(mergeValue),
+ self.context.clean(mergeCombiners))
if (self.partitioner == Some(partitioner)) {
self.mapPartitions(iter => {
val context = TaskContext.get()
@@ -480,7 +486,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
*/
def join[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, W))] = {
this.cogroup(other, partitioner).flatMapValues( pair =>
- for (v <- pair._1; w <- pair._2) yield (v, w)
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, w)
)
}
@@ -493,9 +499,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def leftOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner): RDD[(K, (V, Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues { pair =>
if (pair._2.isEmpty) {
- pair._1.map(v => (v, None))
+ pair._1.iterator.map(v => (v, None))
} else {
- for (v <- pair._1; w <- pair._2) yield (v, Some(w))
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (v, Some(w))
}
}
}
@@ -510,9 +516,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
: RDD[(K, (Option[V], W))] = {
this.cogroup(other, partitioner).flatMapValues { pair =>
if (pair._1.isEmpty) {
- pair._2.map(w => (None, w))
+ pair._2.iterator.map(w => (None, w))
} else {
- for (v <- pair._1; w <- pair._2) yield (Some(v), w)
+ for (v <- pair._1.iterator; w <- pair._2.iterator) yield (Some(v), w)
}
}
}
@@ -528,9 +534,9 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def fullOuterJoin[W](other: RDD[(K, W)], partitioner: Partitioner)
: RDD[(K, (Option[V], Option[W]))] = {
this.cogroup(other, partitioner).flatMapValues {
- case (vs, Seq()) => vs.map(v => (Some(v), None))
- case (Seq(), ws) => ws.map(w => (None, Some(w)))
- case (vs, ws) => for (v <- vs; w <- ws) yield (Some(v), Some(w))
+ case (vs, Seq()) => vs.iterator.map(v => (Some(v), None))
+ case (Seq(), ws) => ws.iterator.map(w => (None, Some(w)))
+ case (vs, ws) => for (v <- vs.iterator; w <- ws.iterator) yield (Some(v), Some(w))
}
}
@@ -955,36 +961,46 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
val outfmt = job.getOutputFormatClass
val jobFormat = outfmt.newInstance
- if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) {
+ if (isOutputSpecValidationEnabled) {
// FileOutputFormat ignores the filesystem parameter
jobFormat.checkOutputSpecs(job)
}
val writeShard = (context: TaskContext, iter: Iterator[(K,V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
/* "reduce task" */
val attemptId = newTaskAttemptID(jobtrackerID, stageId, isMap = false, context.partitionId,
attemptNumber)
- val hadoopContext = newTaskAttemptContext(wrappedConf.value, attemptId)
+ val hadoopContext = newTaskAttemptContext(config, attemptId)
val format = outfmt.newInstance
format match {
- case c: Configurable => c.setConf(wrappedConf.value)
+ case c: Configurable => c.setConf(config)
case _ => ()
}
val committer = format.getOutputCommitter(hadoopContext)
committer.setupTask(hadoopContext)
+
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+
val writer = format.getRecordWriter(hadoopContext).asInstanceOf[NewRecordWriter[K,V]]
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val pair = iter.next()
writer.write(pair._1, pair._2)
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close(hadoopContext)
}
committer.commitTask(hadoopContext)
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
1
} : Int
@@ -1005,6 +1021,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
def saveAsHadoopDataset(conf: JobConf) {
// Rename this as hadoopConf internally to avoid shadowing (see SPARK-2038).
val hadoopConf = conf
+ val wrappedConf = new SerializableWritable(hadoopConf)
val outputFormatInstance = hadoopConf.getOutputFormat
val keyClass = hadoopConf.getOutputKeyClass
val valueClass = hadoopConf.getOutputValueClass
@@ -1022,7 +1039,7 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
logDebug("Saving as hadoop file of type (" + keyClass.getSimpleName + ", " +
valueClass.getSimpleName + ")")
- if (self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)) {
+ if (isOutputSpecValidationEnabled) {
// FileOutputFormat ignores the filesystem parameter
val ignoredFs = FileSystem.get(hadoopConf)
hadoopConf.getOutputFormat.checkOutputSpecs(ignoredFs, hadoopConf)
@@ -1032,27 +1049,56 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
writer.preSetup()
val writeToFile = (context: TaskContext, iter: Iterator[(K, V)]) => {
+ val config = wrappedConf.value
// Hadoop wants a 32-bit task attempt ID, so if ours is bigger than Int.MaxValue, roll it
// around by taking a mod. We expect that no task will be attempted 2 billion times.
val attemptNumber = (context.attemptId % Int.MaxValue).toInt
+ val (outputMetrics, bytesWrittenCallback) = initHadoopOutputMetrics(context, config)
+
writer.setup(context.stageId, context.partitionId, attemptNumber)
writer.open()
try {
+ var recordsWritten = 0L
while (iter.hasNext) {
val record = iter.next()
writer.write(record._1.asInstanceOf[AnyRef], record._2.asInstanceOf[AnyRef])
+
+ // Update bytes written metric every few records
+ maybeUpdateOutputMetrics(bytesWrittenCallback, outputMetrics, recordsWritten)
+ recordsWritten += 1
}
} finally {
writer.close()
}
writer.commit()
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
}
self.context.runJob(self, writeToFile)
writer.commitJob()
}
+ private def initHadoopOutputMetrics(context: TaskContext, config: Configuration)
+ : (OutputMetrics, Option[() => Long]) = {
+ val bytesWrittenCallback = Option(config.get("mapreduce.output.fileoutputformat.outputdir"))
+ .map(new Path(_))
+ .flatMap(SparkHadoopUtil.get.getFSBytesWrittenOnThreadCallback(_, config))
+ val outputMetrics = new OutputMetrics(DataWriteMethod.Hadoop)
+ if (bytesWrittenCallback.isDefined) {
+ context.taskMetrics.outputMetrics = Some(outputMetrics)
+ }
+ (outputMetrics, bytesWrittenCallback)
+ }
+
+ private def maybeUpdateOutputMetrics(bytesWrittenCallback: Option[() => Long],
+ outputMetrics: OutputMetrics, recordsWritten: Long): Unit = {
+ if (recordsWritten % PairRDDFunctions.RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES == 0
+ && bytesWrittenCallback.isDefined) {
+ bytesWrittenCallback.foreach { fn => outputMetrics.bytesWritten = fn() }
+ }
+ }
+
/**
* Return an RDD with the keys of each tuple.
*/
@@ -1068,4 +1114,22 @@ class PairRDDFunctions[K, V](self: RDD[(K, V)])
private[spark] def valueClass: Class[_] = vt.runtimeClass
private[spark] def keyOrdering: Option[Ordering[K]] = Option(ord)
+
+ // Note: this needs to be a function instead of a 'val' so that the disableOutputSpecValidation
+ // setting can take effect:
+ private def isOutputSpecValidationEnabled: Boolean = {
+ val validationDisabled = PairRDDFunctions.disableOutputSpecValidation.value
+ val enabledInConf = self.conf.getBoolean("spark.hadoop.validateOutputSpecs", true)
+ enabledInConf && !validationDisabled
+ }
+}
+
+private[spark] object PairRDDFunctions {
+ val RECORDS_BETWEEN_BYTES_WRITTEN_METRIC_UPDATES = 256
+
+ /**
+ * Allows for the `spark.hadoop.validateOutputSpecs` checks to be disabled on a case-by-case
+ * basis; see SPARK-4835 for more details.
+ */
+ val disableOutputSpecValidation: DynamicVariable[Boolean] = new DynamicVariable[Boolean](false)
}
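isOutputSpecValidationEnabled combines the spark.hadoop.validateOutputSpecs setting with a dynamically scoped switch, so a caller can disable the check for a single write (SPARK-4835) without touching global state. A minimal sketch of the DynamicVariable scoping involved:

    import scala.util.DynamicVariable

    object OutputSpecValidationSketch {
      // Mirrors PairRDDFunctions.disableOutputSpecValidation: a dynamically
      // scoped switch rather than a global mutable flag.
      val disableOutputSpecValidation = new DynamicVariable[Boolean](false)

      def isOutputSpecValidationEnabled(enabledInConf: Boolean): Boolean =
        enabledInConf && !disableOutputSpecValidation.value

      def main(args: Array[String]): Unit = {
        println(isOutputSpecValidationEnabled(enabledInConf = true))   // true
        disableOutputSpecValidation.withValue(true) {
          // Inside this block validation is skipped even though the conf enables it.
          println(isOutputSpecValidationEnabled(enabledInConf = true)) // false
        }
        println(isOutputSpecValidationEnabled(enabledInConf = true))   // true again
      }
    }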
diff --git a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
index 87b22de6ae697..f12d0cffaba34 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ParallelCollectionRDD.scala
@@ -111,7 +111,8 @@ private object ParallelCollectionRDD {
/**
* Slice a collection into numSlices sub-collections. One extra thing we do here is to treat Range
* collections specially, encoding the slices as other Ranges to minimize memory cost. This makes
- * it efficient to run Spark over RDDs representing large sets of numbers.
+ * it efficient to run Spark over RDDs representing large sets of numbers. If the collection
+ * is an inclusive Range, we use an inclusive Range for the last slice.
*/
def slice[T: ClassTag](seq: Seq[T], numSlices: Int): Seq[Seq[T]] = {
if (numSlices < 1) {
@@ -127,19 +128,15 @@ private object ParallelCollectionRDD {
})
}
seq match {
- case r: Range.Inclusive => {
- val sign = if (r.step < 0) {
- -1
- } else {
- 1
- }
- slice(new Range(
- r.start, r.end + sign, r.step).asInstanceOf[Seq[T]], numSlices)
- }
case r: Range => {
- positions(r.length, numSlices).map({
- case (start, end) =>
+ positions(r.length, numSlices).zipWithIndex.map({ case ((start, end), index) =>
+ // If the range is inclusive, use inclusive range for the last slice
+ if (r.isInclusive && index == numSlices - 1) {
+ new Range.Inclusive(r.start + start * r.step, r.end, r.step)
+ }
+ else {
new Range(r.start + start * r.step, r.start + end * r.step, r.step)
+ }
}).toSeq.asInstanceOf[Seq[Seq[T]]]
}
case nr: NumericRange[_] => {
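The Range branch now keeps the last slice inclusive so the end point of an inclusive Range is not silently dropped. A standalone restatement of the slicing arithmetic plus a quick check on 1 to 10 split three ways:

    object InclusiveRangeSliceSketch {
      // Same positions() arithmetic as ParallelCollectionRDD.slice above.
      def positions(length: Long, numSlices: Int): Iterator[(Int, Int)] =
        (0 until numSlices).iterator.map { i =>
          val start = ((i * length) / numSlices).toInt
          val end = (((i + 1) * length) / numSlices).toInt
          (start, end)
        }

      // Standalone re-statement of the Range branch: only the last slice of an
      // inclusive Range stays inclusive, so the end point is kept.
      def sliceRange(r: Range, numSlices: Int): Seq[Range] =
        positions(r.length, numSlices).zipWithIndex.map { case ((start, end), index) =>
          if (r.isInclusive && index == numSlices - 1) {
            Range.inclusive(r.start + start * r.step, r.end, r.step)
          } else {
            Range(r.start + start * r.step, r.start + end * r.step, r.step)
          }
        }.toSeq

      def main(args: Array[String]): Unit = {
        // 1 to 10 in 3 slices: List(1,2,3), List(4,5,6) and the inclusive List(7,8,9,10)
        sliceRange(1 to 10, 3).foreach(s => println(s.toList))
      }
    }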
diff --git a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
index 56ac7a69be0d3..ed79032893d33 100644
--- a/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/PipedRDD.scala
@@ -63,7 +63,7 @@ private[spark] class PipedRDD[T: ClassTag](
/**
* A FilenameFilter that accepts anything that isn't equal to the name passed in.
- * @param name of file or directory to leave out
+ * @param filterName of file or directory to leave out
*/
class NotEqualsFileNameFilter(filterName: String) extends FilenameFilter {
def accept(dir: File, name: String): Boolean = {
diff --git a/core/src/main/scala/org/apache/spark/rdd/RDD.scala b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
index c169b2d3fe97f..1814318a8bf97 100644
--- a/core/src/main/scala/org/apache/spark/rdd/RDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/RDD.scala
@@ -75,10 +75,27 @@ import org.apache.spark.util.random.{BernoulliSampler, PoissonSampler, Bernoulli
* on RDD internals.
*/
abstract class RDD[T: ClassTag](
- @transient private var sc: SparkContext,
+ @transient private var _sc: SparkContext,
@transient private var deps: Seq[Dependency[_]]
) extends Serializable with Logging {
+ if (classOf[RDD[_]].isAssignableFrom(elementClassTag.runtimeClass)) {
+ // This is a warning instead of an exception in order to avoid breaking user programs that
+ // might have defined nested RDDs without running jobs with them.
+ logWarning("Spark does not support nested RDDs (see SPARK-5063)")
+ }
+
+ private def sc: SparkContext = {
+ if (_sc == null) {
+ throw new SparkException(
+ "RDD transformations and actions can only be invoked by the driver, not inside of other " +
+ "transformations; for example, rdd1.map(x => rdd2.values.count() * x) is invalid because " +
+ "the values transformation and count action cannot be performed inside of the rdd1.map " +
+ "transformation. For more information, see SPARK-5063.")
+ }
+ _sc
+ }
+
/** Construct an RDD with just a one-to-one dependency on one parent */
def this(@transient oneParent: RDD[_]) =
this(oneParent.context , List(new OneToOneDependency(oneParent)))
@@ -1096,7 +1113,7 @@ abstract class RDD[T: ClassTag](
}
/**
- * Returns the top K (largest) elements from this RDD as defined by the specified
+ * Returns the top k (largest) elements from this RDD as defined by the specified
* implicit Ordering[T]. This does the opposite of [[takeOrdered]]. For example:
* {{{
* sc.parallelize(Seq(10, 4, 2, 12, 3)).top(1)
@@ -1106,14 +1123,14 @@ abstract class RDD[T: ClassTag](
* // returns Array(6, 5)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of top elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
def top(num: Int)(implicit ord: Ordering[T]): Array[T] = takeOrdered(num)(ord.reverse)
/**
- * Returns the first K (smallest) elements from this RDD as defined by the specified
+ * Returns the first k (smallest) elements from this RDD as defined by the specified
* implicit Ordering[T] and maintains the ordering. This does the opposite of [[top]].
* For example:
* {{{
@@ -1124,7 +1141,7 @@ abstract class RDD[T: ClassTag](
* // returns Array(2, 3)
* }}}
*
- * @param num the number of top elements to return
+ * @param num k, the number of elements to return
* @param ord the implicit ordering for T
* @return an array of top elements
*/
@@ -1132,15 +1149,20 @@ abstract class RDD[T: ClassTag](
if (num == 0) {
Array.empty
} else {
- mapPartitions { items =>
+ val mapRDDs = mapPartitions { items =>
// Priority keeps the largest elements, so let's reverse the ordering.
val queue = new BoundedPriorityQueue[T](num)(ord.reverse)
queue ++= util.collection.Utils.takeOrdered(items, num)(ord)
Iterator.single(queue)
- }.reduce { (queue1, queue2) =>
- queue1 ++= queue2
- queue1
- }.toArray.sorted(ord)
+ }
+ if (mapRDDs.partitions.size == 0) {
+ Array.empty
+ } else {
+ mapRDDs.reduce { (queue1, queue2) =>
+ queue1 ++= queue2
+ queue1
+ }.toArray.sorted(ord)
+ }
}
}
@@ -1160,7 +1182,20 @@ abstract class RDD[T: ClassTag](
* Save this RDD as a text file, using string representations of elements.
*/
def saveAsTextFile(path: String) {
- this.map(x => (NullWritable.get(), new Text(x.toString)))
+ // https://issues.apache.org/jira/browse/SPARK-2075
+ //
+ // NullWritable is a `Comparable` in Hadoop 1.+, so the compiler cannot find an implicit
+ // Ordering for it and will use the default `null`. However, it's a `Comparable[NullWritable]`
+ // in Hadoop 2.+, so the compiler will call the implicit `Ordering.ordered` method to create an
+ // Ordering for `NullWritable`. That's why the compiler will generate different anonymous
+ // classes for `saveAsTextFile` in Hadoop 1.+ and Hadoop 2.+.
+ //
+    // Therefore, here we provide an explicit Ordering `null` to make sure the compiler generates
+    // the same bytecode for `saveAsTextFile`.
+ val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
+ val textClassTag = implicitly[ClassTag[Text]]
+ val r = this.map(x => (NullWritable.get(), new Text(x.toString)))
+ rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path)
}
@@ -1168,7 +1203,11 @@ abstract class RDD[T: ClassTag](
* Save this RDD as a compressed text file, using string representations of elements.
*/
def saveAsTextFile(path: String, codec: Class[_ <: CompressionCodec]) {
- this.map(x => (NullWritable.get(), new Text(x.toString)))
+ // https://issues.apache.org/jira/browse/SPARK-2075
+ val nullWritableClassTag = implicitly[ClassTag[NullWritable]]
+ val textClassTag = implicitly[ClassTag[Text]]
+ val r = this.map(x => (NullWritable.get(), new Text(x.toString)))
+ rddToPairRDDFunctions(r)(nullWritableClassTag, textClassTag, null)
.saveAsHadoopFile[TextOutputFormat[NullWritable, Text]](path, codec)
}
@@ -1202,7 +1241,7 @@ abstract class RDD[T: ClassTag](
*/
def checkpoint() {
if (context.checkpointDir.isEmpty) {
- throw new Exception("Checkpoint directory has not been set in the SparkContext")
+ throw new SparkException("Checkpoint directory has not been set in the SparkContext")
} else if (checkpointData.isEmpty) {
checkpointData = Some(new RDDCheckpointData(this))
checkpointData.get.markForCheckpoint()
@@ -1309,7 +1348,7 @@ abstract class RDD[T: ClassTag](
def debugSelf (rdd: RDD[_]): Seq[String] = {
import Utils.bytesToString
- val persistence = storageLevel.description
+ val persistence = if (storageLevel != StorageLevel.NONE) storageLevel.description else ""
val storageInfo = rdd.context.getRDDStorageInfo.filter(_.id == rdd.id).map(info =>
" CachedPartitions: %d; MemorySize: %s; TachyonSize: %s; DiskSize: %s".format(
info.numCachedPartitions, bytesToString(info.memSize),
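Two of the RDD changes above are worth a sketch: takeOrdered now reduces per-partition bounded queues and returns an empty array when the RDD has no partitions, since a reduce over nothing would throw. A simplified, non-Spark model of that merge and guard:

    object TakeOrderedSketch {
      // Per-partition step standing in for Spark's BoundedPriorityQueue: keep
      // only the k smallest elements seen so far.
      def topK(items: Iterator[Int], k: Int): Seq[Int] =
        items.toSeq.sorted.take(k)

      // Merge step mirroring mapRDDs.reduce above, with the empty-partition guard.
      def takeOrdered(partitions: Seq[Seq[Int]], k: Int): Array[Int] = {
        val perPartition = partitions.map(p => topK(p.iterator, k))
        if (perPartition.isEmpty) {
          Array.empty
        } else {
          perPartition.reduce((a, b) => (a ++ b).sorted.take(k)).toArray
        }
      }

      def main(args: Array[String]): Unit = {
        println(takeOrdered(Seq(Seq(10, 4, 2), Seq(12, 3)), k = 2).toList) // List(2, 3)
        println(takeOrdered(Seq.empty, k = 2).toList)                      // List(), not an exception
      }
    }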
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
index 996f2cd3f34a3..95b2dd954e9f4 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedPartitionsRDD.scala
@@ -77,7 +77,7 @@ private[spark] abstract class ZippedPartitionsBaseRDD[V: ClassTag](
private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag](
sc: SparkContext,
- f: (Iterator[A], Iterator[B]) => Iterator[V],
+ var f: (Iterator[A], Iterator[B]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
preservesPartitioning: Boolean = false)
@@ -92,13 +92,14 @@ private[spark] class ZippedPartitionsRDD2[A: ClassTag, B: ClassTag, V: ClassTag]
super.clearDependencies()
rdd1 = null
rdd2 = null
+ f = null
}
}
private[spark] class ZippedPartitionsRDD3
[A: ClassTag, B: ClassTag, C: ClassTag, V: ClassTag](
sc: SparkContext,
- f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
+ var f: (Iterator[A], Iterator[B], Iterator[C]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
@@ -117,13 +118,14 @@ private[spark] class ZippedPartitionsRDD3
rdd1 = null
rdd2 = null
rdd3 = null
+ f = null
}
}
private[spark] class ZippedPartitionsRDD4
[A: ClassTag, B: ClassTag, C: ClassTag, D:ClassTag, V: ClassTag](
sc: SparkContext,
- f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
+ var f: (Iterator[A], Iterator[B], Iterator[C], Iterator[D]) => Iterator[V],
var rdd1: RDD[A],
var rdd2: RDD[B],
var rdd3: RDD[C],
@@ -145,5 +147,6 @@ private[spark] class ZippedPartitionsRDD4
rdd2 = null
rdd3 = null
rdd4 = null
+ f = null
}
}
diff --git a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
index e2c301603b4a5..8c43a559409f2 100644
--- a/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
+++ b/core/src/main/scala/org/apache/spark/rdd/ZippedWithIndexRDD.scala
@@ -39,21 +39,24 @@ class ZippedWithIndexRDDPartition(val prev: Partition, val startIndex: Long)
private[spark]
class ZippedWithIndexRDD[T: ClassTag](@transient prev: RDD[T]) extends RDD[(T, Long)](prev) {
- override def getPartitions: Array[Partition] = {
+ /** The start index of each partition. */
+ @transient private val startIndices: Array[Long] = {
val n = prev.partitions.size
- val startIndices: Array[Long] =
- if (n == 0) {
- Array[Long]()
- } else if (n == 1) {
- Array(0L)
- } else {
- prev.context.runJob(
- prev,
- Utils.getIteratorSize _,
- 0 until n - 1, // do not need to count the last partition
- false
- ).scanLeft(0L)(_ + _)
- }
+ if (n == 0) {
+ Array[Long]()
+ } else if (n == 1) {
+ Array(0L)
+ } else {
+ prev.context.runJob(
+ prev,
+ Utils.getIteratorSize _,
+ 0 until n - 1, // do not need to count the last partition
+ allowLocal = false
+ ).scanLeft(0L)(_ + _)
+ }
+ }
+
+ override def getPartitions: Array[Partition] = {
firstParent[T].partitions.map(x => new ZippedWithIndexRDDPartition(x, startIndices(x.index)))
}
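The per-partition start indices are just a running sum (scanLeft) over the sizes of all partitions except the last, which never needs to be counted. A small worked example of that offset computation:

    object ZipWithIndexOffsetsSketch {
      // Start offsets for each partition, mirroring the scanLeft in
      // ZippedWithIndexRDD above; partitionSizes is illustrative input.
      def startIndices(partitionSizes: Seq[Long]): Seq[Long] =
        partitionSizes match {
          case Seq()  => Seq.empty
          case Seq(_) => Seq(0L)
          case sizes  => sizes.init.scanLeft(0L)(_ + _)
        }

      def main(args: Array[String]): Unit = {
        // Partitions with 3, 2 and 4 elements start at global indices 0, 3 and 5.
        println(startIndices(Seq(3L, 2L, 4L))) // List(0, 3, 5)
        println(startIndices(Seq(7L)))         // List(0)
        println(startIndices(Seq.empty))       // List()
      }
    }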
diff --git a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
index 96114c0423a9e..cb8ccfbdbdcbb 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/DAGScheduler.scala
@@ -449,7 +449,6 @@ class DAGScheduler(
}
// data structures based on StageId
stageIdToStage -= stageId
-
logDebug("After removal of stage %d, remaining stages = %d"
.format(stageId, stageIdToStage.size))
}
@@ -751,14 +750,15 @@ class DAGScheduler(
localExecutionEnabled && allowLocal && finalStage.parents.isEmpty && partitions.length == 1
if (shouldRunLocally) {
// Compute very short actions like first() or take() with no parent stages locally.
- listenerBus.post(SparkListenerJobStart(job.jobId, Array[Int](), properties))
+ listenerBus.post(SparkListenerJobStart(job.jobId, Seq.empty, properties))
runLocally(job)
} else {
jobIdToActiveJob(jobId) = job
activeJobs += job
finalStage.resultOfJob = Some(job)
- listenerBus.post(SparkListenerJobStart(job.jobId, jobIdToStageIds(jobId).toArray,
- properties))
+ val stageIds = jobIdToStageIds(jobId).toArray
+ val stageInfos = stageIds.flatMap(id => stageIdToStage.get(id).map(_.latestInfo))
+ listenerBus.post(SparkListenerJobStart(job.jobId, stageInfos, properties))
submitStage(finalStage)
}
}
@@ -901,6 +901,34 @@ class DAGScheduler(
}
}
+ /** Merge updates from a task to our local accumulator values */
+ private def updateAccumulators(event: CompletionEvent): Unit = {
+ val task = event.task
+ val stage = stageIdToStage(task.stageId)
+ if (event.accumUpdates != null) {
+ try {
+ Accumulators.add(event.accumUpdates)
+ event.accumUpdates.foreach { case (id, partialValue) =>
+ val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
+ // To avoid UI cruft, ignore cases where value wasn't updated
+ if (acc.name.isDefined && partialValue != acc.zero) {
+ val name = acc.name.get
+ val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
+ val stringValue = Accumulators.stringifyValue(acc.value)
+ stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
+ event.taskInfo.accumulables +=
+ AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
+ }
+ }
+ } catch {
+ // If we see an exception during accumulator update, just log the
+ // error and move on.
+ case e: Exception =>
+ logError(s"Failed to update accumulators for $task", e)
+ }
+ }
+ }
+
/**
* Responds to a task finishing. This is called inside the event loop so it assumes that it can
* modify the scheduler's internal state. Use taskEnded() to post a task end event from outside.
@@ -941,27 +969,6 @@ class DAGScheduler(
}
event.reason match {
case Success =>
- if (event.accumUpdates != null) {
- try {
- Accumulators.add(event.accumUpdates)
- event.accumUpdates.foreach { case (id, partialValue) =>
- val acc = Accumulators.originals(id).asInstanceOf[Accumulable[Any, Any]]
- // To avoid UI cruft, ignore cases where value wasn't updated
- if (acc.name.isDefined && partialValue != acc.zero) {
- val name = acc.name.get
- val stringPartialValue = Accumulators.stringifyPartialValue(partialValue)
- val stringValue = Accumulators.stringifyValue(acc.value)
- stage.latestInfo.accumulables(id) = AccumulableInfo(id, name, stringValue)
- event.taskInfo.accumulables +=
- AccumulableInfo(id, name, Some(stringPartialValue), stringValue)
- }
- }
- } catch {
- // If we see an exception during accumulator update, just log the error and move on.
- case e: Exception =>
- logError(s"Failed to update accumulators for $task", e)
- }
- }
listenerBus.post(SparkListenerTaskEnd(stageId, stage.latestInfo.attemptId, taskType,
event.reason, event.taskInfo, event.taskMetrics))
stage.pendingTasks -= task
@@ -970,6 +977,7 @@ class DAGScheduler(
stage.resultOfJob match {
case Some(job) =>
if (!job.finished(rt.outputId)) {
+ updateAccumulators(event)
job.finished(rt.outputId) = true
job.numFinished += 1
// If the whole job has finished, remove it
@@ -994,6 +1002,7 @@ class DAGScheduler(
}
case smt: ShuffleMapTask =>
+ updateAccumulators(event)
val status = event.result.asInstanceOf[MapStatus]
val execId = status.location.executorId
logDebug("ShuffleMapTask finished on " + execId)
@@ -1063,7 +1072,7 @@ class DAGScheduler(
if (runningStages.contains(failedStage)) {
logInfo(s"Marking $failedStage (${failedStage.name}) as failed " +
s"due to a fetch failure from $mapStage (${mapStage.name})")
- markStageAsFinished(failedStage, Some("Fetch failure: " + failureMessage))
+ markStageAsFinished(failedStage, Some(failureMessage))
runningStages -= failedStage
}
@@ -1082,7 +1091,6 @@ class DAGScheduler(
}
failedStages += failedStage
failedStages += mapStage
-
// Mark the map whose fetch failed as broken in the map stage
if (mapId != -1) {
mapStage.removeOutputLoc(mapId, bmAddress)
@@ -1094,7 +1102,7 @@ class DAGScheduler(
handleExecutorLost(bmAddress.executorId, fetchFailed = true, Some(task.epoch))
}
- case ExceptionFailure(className, description, stackTrace, metrics) =>
+ case ExceptionFailure(className, description, stackTrace, fullStackTrace, metrics) =>
// Do nothing here, left up to the TaskScheduler to decide how to handle user failures
case TaskResultLost =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
index 4e3d9de540783..3bb54855bae44 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/JobLogger.scala
@@ -158,6 +158,11 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" INPUT_BYTES=" + metrics.bytesRead
case None => ""
}
+ val outputMetrics = taskMetrics.outputMetrics match {
+ case Some(metrics) =>
+ " OUTPUT_BYTES=" + metrics.bytesWritten
+ case None => ""
+ }
val shuffleReadMetrics = taskMetrics.shuffleReadMetrics match {
case Some(metrics) =>
" BLOCK_FETCHED_TOTAL=" + metrics.totalBlocksFetched +
@@ -173,7 +178,7 @@ class JobLogger(val user: String, val logDirName: String) extends SparkListener
" SHUFFLE_WRITE_TIME=" + metrics.shuffleWriteTime
case None => ""
}
- stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics +
+ stageLogInfo(stageId, status + info + executorRunTime + gcTime + inputMetrics + outputMetrics +
shuffleReadMetrics + writeMetrics)
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
index 01d5943d777f3..1efce124c0a6b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/MapStatus.scala
@@ -122,7 +122,7 @@ private[spark] class CompressedMapStatus(
/**
* A [[MapStatus]] implementation that only stores the average size of non-empty blocks,
- * plus a bitmap for tracking which blocks are non-empty. During serialization, this bitmap
+ * plus a bitmap for tracking which blocks are empty. During serialization, this bitmap
* is compressed.
*
* @param loc location where the task is being executed
diff --git a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
index 86afe3bd5265f..b62b0c1312693 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/SparkListener.scala
@@ -56,8 +56,15 @@ case class SparkListenerTaskEnd(
extends SparkListenerEvent
@DeveloperApi
-case class SparkListenerJobStart(jobId: Int, stageIds: Seq[Int], properties: Properties = null)
- extends SparkListenerEvent
+case class SparkListenerJobStart(
+ jobId: Int,
+ stageInfos: Seq[StageInfo],
+ properties: Properties = null)
+ extends SparkListenerEvent {
+ // Note: this is here for backwards-compatibility with older versions of this event which
+ // only stored stageIds and not StageInfos:
+ val stageIds: Seq[Int] = stageInfos.map(_.stageId)
+}
@DeveloperApi
case class SparkListenerJobEnd(jobId: Int, jobResult: JobResult) extends SparkListenerEvent
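For illustration, the backwards-compatible `stageIds` field in the new SparkListenerJobStart is just a projection of the richer `stageInfos` payload. A minimal standalone sketch of how that compatibility shim behaves (a toy `StageInfoSketch` case class; the real StageInfo carries far more fields):

// Toy model: the real StageInfo lives in org.apache.spark.scheduler and has more fields.
case class StageInfoSketch(stageId: Int, name: String, numTasks: Int)

case class JobStartSketch(jobId: Int, stageInfos: Seq[StageInfoSketch]) {
  // Derived exactly like the patched event: old consumers keep reading stage IDs.
  val stageIds: Seq[Int] = stageInfos.map(_.stageId)
}

object JobStartCompatDemo extends App {
  val event = JobStartSketch(0, Seq(StageInfoSketch(1, "map", 4), StageInfoSketch(2, "reduce", 2)))
  println(event.stageIds)                       // List(1, 2) -- what pre-patch listeners saw
  println(event.stageInfos.map(_.numTasks).sum) // 6 -- extra detail new listeners can use
}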
diff --git a/core/src/main/scala/org/apache/spark/scheduler/Task.scala b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
index 2552d03d18d06..d7dde4fe38436 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/Task.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/Task.scala
@@ -45,7 +45,7 @@ import org.apache.spark.util.Utils
private[spark] abstract class Task[T](val stageId: Int, var partitionId: Int) extends Serializable {
final def run(attemptId: Long): T = {
- context = new TaskContextImpl(stageId, partitionId, attemptId, false)
+ context = new TaskContextImpl(stageId, partitionId, attemptId, runningLocally = false)
TaskContextHelper.setTaskContext(context)
context.taskMetrics.hostname = Utils.localHostName()
taskThread = Thread.currentThread()
diff --git a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
index d8fb640350343..cabdc655f89bf 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/TaskSetManager.scala
@@ -536,7 +536,7 @@ private[spark] class TaskSetManager(
calculatedTasks += 1
if (maxResultSize > 0 && totalResultSize > maxResultSize) {
val msg = s"Total size of serialized results of ${calculatedTasks} tasks " +
- s"(${Utils.bytesToString(totalResultSize)}) is bigger than maxResultSize " +
+ s"(${Utils.bytesToString(totalResultSize)}) is bigger than spark.driver.maxResultSize " +
s"(${Utils.bytesToString(maxResultSize)})"
logError(msg)
abort(msg)
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
index 7a6ee56f81689..fe9914b50bc54 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/CoarseGrainedSchedulerBackend.scala
@@ -27,7 +27,7 @@ import akka.actor._
import akka.pattern.ask
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
-import org.apache.spark.{SparkEnv, Logging, SparkException, TaskState}
+import org.apache.spark.{ExecutorAllocationClient, Logging, SparkEnv, SparkException, TaskState}
import org.apache.spark.scheduler.{SchedulerBackend, SlaveLost, TaskDescription, TaskSchedulerImpl, WorkerOffer}
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Utils}
@@ -42,10 +42,11 @@ import org.apache.spark.util.{ActorLogReceive, SerializableBuffer, AkkaUtils, Ut
*/
private[spark]
class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSystem: ActorSystem)
- extends SchedulerBackend with Logging
+ extends ExecutorAllocationClient with SchedulerBackend with Logging
{
// Use an atomic variable to track total number of cores in the cluster for simplicity and speed
var totalCoreCount = new AtomicInteger(0)
+ // Total number of executors that are currently registered
var totalRegisteredExecutors = new AtomicInteger(0)
val conf = scheduler.sc.conf
private val timeout = AkkaUtils.askTimeout(conf)
@@ -126,7 +127,13 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
makeOffers()
case KillTask(taskId, executorId, interruptThread) =>
- executorDataMap(executorId).executorActor ! KillTask(taskId, executorId, interruptThread)
+ executorDataMap.get(executorId) match {
+ case Some(executorInfo) =>
+ executorInfo.executorActor ! KillTask(taskId, executorId, interruptThread)
+ case None =>
+ // Ignoring the task kill since the executor is not registered.
+ logWarning(s"Attempted to kill task $taskId for unknown executor $executorId.")
+ }
case StopDriver =>
sender ! true
@@ -204,6 +211,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
executorsPendingToRemove -= executorId
}
totalCoreCount.addAndGet(-executorInfo.totalCores)
+ totalRegisteredExecutors.addAndGet(-1)
scheduler.executorLost(executorId, SlaveLost(reason))
case None => logError(s"Asked to remove non-existent executor $executorId")
}
@@ -299,7 +307,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* Request an additional number of executors from the cluster manager.
* Return whether the request is acknowledged.
*/
- final def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
+ final override def requestExecutors(numAdditionalExecutors: Int): Boolean = synchronized {
logInfo(s"Requesting $numAdditionalExecutors additional executor(s) from the cluster manager")
logDebug(s"Number of pending executors is now $numPendingExecutors")
numPendingExecutors += numAdditionalExecutors
@@ -326,7 +334,7 @@ class CoarseGrainedSchedulerBackend(scheduler: TaskSchedulerImpl, val actorSyste
* Request that the cluster manager kill the specified executors.
* Return whether the kill request is acknowledged.
*/
- final def killExecutors(executorIds: Seq[String]): Boolean = {
+ final override def killExecutors(executorIds: Seq[String]): Boolean = synchronized {
logInfo(s"Requesting to kill executor(s) ${executorIds.mkString(", ")}")
val filteredExecutorIds = new ArrayBuffer[String]
executorIds.foreach { id =>
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
index 50721b9d6cd6c..f14aaeea0a25c 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/YarnSchedulerBackend.scala
@@ -17,6 +17,8 @@
package org.apache.spark.scheduler.cluster
+import scala.concurrent.{Future, ExecutionContext}
+
import akka.actor.{Actor, ActorRef, Props}
import akka.remote.{DisassociatedEvent, RemotingLifecycleEvent}
@@ -24,7 +26,9 @@ import org.apache.spark.SparkContext
import org.apache.spark.scheduler.cluster.CoarseGrainedClusterMessages._
import org.apache.spark.scheduler.TaskSchedulerImpl
import org.apache.spark.ui.JettyUtils
-import org.apache.spark.util.AkkaUtils
+import org.apache.spark.util.{AkkaUtils, Utils}
+
+import scala.util.control.NonFatal
/**
* Abstract Yarn scheduler backend that contains common logic
@@ -97,6 +101,9 @@ private[spark] abstract class YarnSchedulerBackend(
private class YarnSchedulerActor extends Actor {
private var amActor: Option[ActorRef] = None
+ implicit val askAmActorExecutor = ExecutionContext.fromExecutor(
+ Utils.newDaemonCachedThreadPool("yarn-scheduler-ask-am-executor"))
+
override def preStart(): Unit = {
// Listen for disassociation events
context.system.eventStream.subscribe(self, classOf[RemotingLifecycleEvent])
@@ -110,7 +117,12 @@ private[spark] abstract class YarnSchedulerBackend(
case r: RequestExecutors =>
amActor match {
case Some(actor) =>
- sender ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](r, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $r to AM was unsuccessful", e)
+ }
case None =>
logWarning("Attempted to request executors before the AM has registered!")
sender ! false
@@ -119,7 +131,12 @@ private[spark] abstract class YarnSchedulerBackend(
case k: KillExecutors =>
amActor match {
case Some(actor) =>
- sender ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ val driverActor = sender
+ Future {
+ driverActor ! AkkaUtils.askWithReply[Boolean](k, actor, askTimeout)
+ } onFailure {
+ case NonFatal(e) => logError(s"Sending $k to AM was unsuccessful", e)
+ }
case None =>
logWarning("Attempted to kill executors before the AM has registered!")
sender ! false
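The YarnSchedulerBackend change replaces a blocking `askWithReply` inside the actor's receive loop with a reply computed on a dedicated daemon thread pool. A rough self-contained sketch of that pattern, with a toy blocking call standing in for the Akka ask and a plain `ThreadFactory` standing in for `Utils.newDaemonCachedThreadPool`:

import java.util.concurrent.{Executors, ThreadFactory}
import scala.concurrent.{ExecutionContext, Future}
import scala.util.control.NonFatal

object AsyncAskSketch {
  // Stand-in for Utils.newDaemonCachedThreadPool: a cached pool of daemon threads.
  implicit val askExecutor: ExecutionContext = ExecutionContext.fromExecutor(
    Executors.newCachedThreadPool(new ThreadFactory {
      override def newThread(r: Runnable): Thread = {
        val t = new Thread(r, "ask-am")
        t.setDaemon(true)
        t
      }
    }))

  // Toy stand-in for the blocking AkkaUtils.askWithReply call to the ApplicationMaster.
  def blockingAsk(request: String): Boolean = { Thread.sleep(100); true }

  // Answer asynchronously so the caller (the actor's receive loop) is never blocked.
  def handle(request: String, reply: Boolean => Unit): Unit = {
    Future { reply(blockingAsk(request)) } onFailure {
      case NonFatal(e) => println(s"Sending $request to AM was unsuccessful: $e")
    }
  }

  def main(args: Array[String]): Unit = {
    handle("RequestExecutors(2)", ok => println(s"acknowledged = $ok"))
    Thread.sleep(500) // daemon threads would not keep the JVM alive on their own
  }
}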
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
index d8c0e2f66df01..5289661eb896b 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/CoarseMesosSchedulerBackend.scala
@@ -93,7 +93,7 @@ private[spark] class CoarseMesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = CoarseMesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try { {
val ret = driver.run()
@@ -242,8 +242,7 @@ private[spark] class CoarseMesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Build a Mesos resource protobuf object */
diff --git a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
index 8e2faff90f9b2..10e6886c16a4f 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/cluster/mesos/MesosSchedulerBackend.scala
@@ -72,7 +72,7 @@ private[spark] class MesosSchedulerBackend(
setDaemon(true)
override def run() {
val scheduler = MesosSchedulerBackend.this
- val fwInfo = FrameworkInfo.newBuilder().setUser("").setName(sc.appName).build()
+ val fwInfo = FrameworkInfo.newBuilder().setUser(sc.sparkUser).setName(sc.appName).build()
driver = new MesosSchedulerDriver(scheduler, fwInfo, master)
try {
val ret = driver.run()
@@ -166,29 +166,16 @@ private[spark] class MesosSchedulerBackend(
execArgs
}
- private def setClassLoader(): ClassLoader = {
- val oldClassLoader = Thread.currentThread.getContextClassLoader
- Thread.currentThread.setContextClassLoader(classLoader)
- oldClassLoader
- }
-
- private def restoreClassLoader(oldClassLoader: ClassLoader) {
- Thread.currentThread.setContextClassLoader(oldClassLoader)
- }
-
override def offerRescinded(d: SchedulerDriver, o: OfferID) {}
override def registered(d: SchedulerDriver, frameworkId: FrameworkID, masterInfo: MasterInfo) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
appId = frameworkId.getValue
logInfo("Registered as framework ID " + appId)
registeredLock.synchronized {
isRegistered = true
registeredLock.notifyAll()
}
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -200,6 +187,16 @@ private[spark] class MesosSchedulerBackend(
}
}
+ private def inClassLoader()(fun: => Unit) = {
+ val oldClassLoader = Thread.currentThread.getContextClassLoader
+ Thread.currentThread.setContextClassLoader(classLoader)
+ try {
+ fun
+ } finally {
+ Thread.currentThread.setContextClassLoader(oldClassLoader)
+ }
+ }
+
override def disconnected(d: SchedulerDriver) {}
override def reregistered(d: SchedulerDriver, masterInfo: MasterInfo) {}
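The new `inClassLoader()` helper above is a loan-pattern wrapper: it swaps in the backend's class loader, runs the callback body, and restores the previous loader even if the body throws. A minimal generic sketch of the same pattern (the Mesos callbacks themselves are not reproduced):

object ClassLoaderLoanSketch {
  // Run `body` with `loader` as the thread's context class loader, always restoring the old one.
  def withContextClassLoader[T](loader: ClassLoader)(body: => T): T = {
    val old = Thread.currentThread.getContextClassLoader
    Thread.currentThread.setContextClassLoader(loader)
    try {
      body
    } finally {
      Thread.currentThread.setContextClassLoader(old)
    }
  }

  def main(args: Array[String]): Unit = {
    val appLoader = getClass.getClassLoader
    withContextClassLoader(appLoader) {
      println(s"running with ${Thread.currentThread.getContextClassLoader}")
    }
  }
}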
@@ -210,66 +207,70 @@ private[spark] class MesosSchedulerBackend(
* tasks are balanced across the cluster.
*/
override def resourceOffers(d: SchedulerDriver, offers: JList[Offer]) {
- val oldClassLoader = setClassLoader()
- try {
- synchronized {
- // Build a big list of the offerable workers, and remember their indices so that we can
- // figure out which Offer to reply to for each worker
- val offerableWorkers = new ArrayBuffer[WorkerOffer]
- val offerableIndices = new HashMap[String, Int]
-
- def sufficientOffer(o: Offer) = {
- val mem = getResource(o.getResourcesList, "mem")
- val cpus = getResource(o.getResourcesList, "cpus")
- val slaveId = o.getSlaveId.getValue
- (mem >= MemoryUtils.calculateTotalMemory(sc) &&
- // need at least 1 for executor, 1 for task
- cpus >= 2 * scheduler.CPUS_PER_TASK) ||
- (slaveIdsWithExecutors.contains(slaveId) &&
- cpus >= scheduler.CPUS_PER_TASK)
- }
+ inClassLoader() {
+ // Fail-fast on offers we know will be rejected
+ val (usableOffers, unUsableOffers) = offers.partition { o =>
+ val mem = getResource(o.getResourcesList, "mem")
+ val cpus = getResource(o.getResourcesList, "cpus")
+ val slaveId = o.getSlaveId.getValue
+ // TODO(pwendell): Should below be 1 + scheduler.CPUS_PER_TASK?
+ (mem >= MemoryUtils.calculateTotalMemory(sc) &&
+ // need at least 1 for executor, 1 for task
+ cpus >= 2 * scheduler.CPUS_PER_TASK) ||
+ (slaveIdsWithExecutors.contains(slaveId) &&
+ cpus >= scheduler.CPUS_PER_TASK)
+ }
- for ((offer, index) <- offers.zipWithIndex if sufficientOffer(offer)) {
- val slaveId = offer.getSlaveId.getValue
- offerableIndices.put(slaveId, index)
- val cpus = if (slaveIdsWithExecutors.contains(slaveId)) {
- getResource(offer.getResourcesList, "cpus").toInt
- } else {
- // If the executor doesn't exist yet, subtract CPU for executor
- getResource(offer.getResourcesList, "cpus").toInt -
- scheduler.CPUS_PER_TASK
- }
- offerableWorkers += new WorkerOffer(
- offer.getSlaveId.getValue,
- offer.getHostname,
- cpus)
+ val workerOffers = usableOffers.map { o =>
+ val cpus = if (slaveIdsWithExecutors.contains(o.getSlaveId.getValue)) {
+ getResource(o.getResourcesList, "cpus").toInt
+ } else {
+ // If the executor doesn't exist yet, subtract CPU for executor
+ // TODO(pwendell): Should below just subtract "1"?
+ getResource(o.getResourcesList, "cpus").toInt -
+ scheduler.CPUS_PER_TASK
}
+ new WorkerOffer(
+ o.getSlaveId.getValue,
+ o.getHostname,
+ cpus)
+ }
+
+ val slaveIdToOffer = usableOffers.map(o => o.getSlaveId.getValue -> o).toMap
+
+ val mesosTasks = new HashMap[String, JArrayList[MesosTaskInfo]]
- // Call into the TaskSchedulerImpl
- val taskLists = scheduler.resourceOffers(offerableWorkers)
-
- // Build a list of Mesos tasks for each slave
- val mesosTasks = offers.map(o => new JArrayList[MesosTaskInfo]())
- for ((taskList, index) <- taskLists.zipWithIndex) {
- if (!taskList.isEmpty) {
- for (taskDesc <- taskList) {
- val slaveId = taskDesc.executorId
- val offerNum = offerableIndices(slaveId)
- slaveIdsWithExecutors += slaveId
- taskIdToSlaveId(taskDesc.taskId) = slaveId
- mesosTasks(offerNum).add(createMesosTask(taskDesc, slaveId))
- }
+ val slavesIdsOfAcceptedOffers = HashSet[String]()
+
+ // Call into the TaskSchedulerImpl
+ val acceptedOffers = scheduler.resourceOffers(workerOffers).filter(!_.isEmpty)
+ acceptedOffers
+ .foreach { offer =>
+ offer.foreach { taskDesc =>
+ val slaveId = taskDesc.executorId
+ slaveIdsWithExecutors += slaveId
+ slavesIdsOfAcceptedOffers += slaveId
+ taskIdToSlaveId(taskDesc.taskId) = slaveId
+ mesosTasks.getOrElseUpdate(slaveId, new JArrayList[MesosTaskInfo])
+ .add(createMesosTask(taskDesc, slaveId))
}
}
- // Reply to the offers
- val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
- for (i <- 0 until offers.size) {
- d.launchTasks(Collections.singleton(offers(i).getId), mesosTasks(i), filters)
- }
+ // Reply to the offers
+ val filters = Filters.newBuilder().setRefuseSeconds(1).build() // TODO: lower timeout?
+
+ mesosTasks.foreach { case (slaveId, tasks) =>
+ d.launchTasks(Collections.singleton(slaveIdToOffer(slaveId).getId), tasks, filters)
}
- } finally {
- restoreClassLoader(oldClassLoader)
+
+ // Decline offers that weren't used
+ // NOTE: This logic assumes that we only get a single offer for each host in a given batch
+ for (o <- usableOffers if !slavesIdsOfAcceptedOffers.contains(o.getSlaveId.getValue)) {
+ d.declineOffer(o.getId)
+ }
+
+ // Decline offers we ruled out immediately
+ unUsableOffers.foreach(o => d.declineOffer(o.getId))
}
}
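The rewritten resourceOffers splits incoming offers into usable and unusable up front, launches tasks only on slaves the TaskSchedulerImpl actually accepted, and explicitly declines the rest. A toy sketch of that partition/launch/decline flow (simple case classes stand in for Mesos offers; the thresholds are illustrative, not the real MemoryUtils or CPUS_PER_TASK values):

object OfferTriageSketch {
  case class Offer(slaveId: String, mem: Double, cpus: Double)

  def main(args: Array[String]): Unit = {
    val offers = Seq(Offer("s1", 4096, 4), Offer("s2", 128, 1), Offer("s3", 8192, 8))
    val minMem = 1024.0
    val minCpus = 2.0 // assumption: 1 core for the executor + 1 per task, as in the patch's comment

    // Fail-fast: reject offers that can never host an executor plus a task.
    val (usable, unusable) = offers.partition(o => o.mem >= minMem && o.cpus >= minCpus)

    // Pretend the scheduler accepted tasks only on s1.
    val slavesWithTasks = Set("s1")
    usable.foreach { o =>
      if (slavesWithTasks.contains(o.slaveId)) println(s"launchTasks on ${o.slaveId}")
      else println(s"declineOffer from ${o.slaveId} (usable but unused)")
    }
    unusable.foreach(o => println(s"declineOffer from ${o.slaveId} (rejected up front)"))
  }
}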
@@ -278,8 +279,7 @@ private[spark] class MesosSchedulerBackend(
for (r <- res if r.getName == name) {
return r.getScalar.getValue
}
- // If we reached here, no resource with the required name was present
- throw new IllegalArgumentException("No resource called " + name + " in " + res)
+ 0
}
/** Turn a Spark TaskDescription into a Mesos task */
@@ -309,8 +309,7 @@ private[spark] class MesosSchedulerBackend(
}
override def statusUpdate(d: SchedulerDriver, status: TaskStatus) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
val tid = status.getTaskId.getValue.toLong
val state = TaskState.fromMesos(status.getState)
synchronized {
@@ -323,18 +322,13 @@ private[spark] class MesosSchedulerBackend(
}
}
scheduler.statusUpdate(tid, state, status.getData.asReadOnlyByteBuffer)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
override def error(d: SchedulerDriver, message: String) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
logError("Mesos error: " + message)
scheduler.error(message)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
@@ -351,15 +345,12 @@ private[spark] class MesosSchedulerBackend(
override def frameworkMessage(d: SchedulerDriver, e: ExecutorID, s: SlaveID, b: Array[Byte]) {}
private def recordSlaveLost(d: SchedulerDriver, slaveId: SlaveID, reason: ExecutorLossReason) {
- val oldClassLoader = setClassLoader()
- try {
+ inClassLoader() {
logInfo("Mesos slave lost: " + slaveId.getValue)
synchronized {
slaveIdsWithExecutors -= slaveId.getValue
}
scheduler.executorLost(slaveId.getValue, reason)
- } finally {
- restoreClassLoader(oldClassLoader)
}
}
diff --git a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
index c0264836de738..a2f1f14264a99 100644
--- a/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
+++ b/core/src/main/scala/org/apache/spark/scheduler/local/LocalBackend.scala
@@ -51,7 +51,7 @@ private[spark] class LocalActor(
private val localExecutorHostname = "localhost"
val executor = new Executor(
- localExecutorId, localExecutorHostname, scheduler.conf.getAll, isLocal = true)
+ localExecutorId, localExecutorHostname, scheduler.conf.getAll, totalCores, isLocal = true)
override def receiveWithLogging = {
case ReviveOffers =>
diff --git a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
index 662a7b91248aa..fa8a337ad63a8 100644
--- a/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/JavaSerializer.scala
@@ -92,7 +92,7 @@ private[spark] class JavaSerializerInstance(counterReset: Int, defaultClassLoade
}
override def deserializeStream(s: InputStream): DeserializationStream = {
- new JavaDeserializationStream(s, Utils.getContextOrSparkClassLoader)
+ new JavaDeserializationStream(s, defaultClassLoader)
}
def deserializeStream(s: InputStream, loader: ClassLoader): DeserializationStream = {
diff --git a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
index 621a951c27d07..d56e23ce4478a 100644
--- a/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
+++ b/core/src/main/scala/org/apache/spark/serializer/KryoSerializer.scala
@@ -26,9 +26,10 @@ import com.esotericsoftware.kryo.serializers.{JavaSerializer => KryoJavaSerializ
import com.twitter.chill.{AllScalaRegistrar, EmptyScalaKryoInstantiator}
import org.apache.spark._
+import org.apache.spark.api.python.PythonBroadcast
import org.apache.spark.broadcast.HttpBroadcast
import org.apache.spark.network.nio.{PutBlock, GotBlock, GetBlock}
-import org.apache.spark.scheduler.MapStatus
+import org.apache.spark.scheduler.{CompressedMapStatus, HighlyCompressedMapStatus}
import org.apache.spark.storage._
import org.apache.spark.util.BoundedPriorityQueue
import org.apache.spark.util.collection.CompactBuffer
@@ -90,6 +91,7 @@ class KryoSerializer(conf: SparkConf)
// Allow sending SerializableWritable
kryo.register(classOf[SerializableWritable[_]], new KryoJavaSerializer())
kryo.register(classOf[HttpBroadcast[_]], new KryoJavaSerializer())
+ kryo.register(classOf[PythonBroadcast], new KryoJavaSerializer())
try {
// Use the default classloader when calling the user registrator.
@@ -205,7 +207,8 @@ private[serializer] object KryoSerializer {
classOf[PutBlock],
classOf[GotBlock],
classOf[GetBlock],
- classOf[MapStatus],
+ classOf[CompressedMapStatus],
+ classOf[HighlyCompressedMapStatus],
classOf[CompactBuffer[_]],
classOf[BlockManagerId],
classOf[Array[Byte]],
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
index 0c1b6f4defdb3..be184464e0ae9 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FetchFailedException.scala
@@ -32,10 +32,21 @@ private[spark] class FetchFailedException(
shuffleId: Int,
mapId: Int,
reduceId: Int,
- message: String)
- extends Exception(message) {
+ message: String,
+ cause: Throwable = null)
+ extends Exception(message, cause) {
+
+ def this(
+ bmAddress: BlockManagerId,
+ shuffleId: Int,
+ mapId: Int,
+ reduceId: Int,
+ cause: Throwable) {
+ this(bmAddress, shuffleId, mapId, reduceId, cause.getMessage, cause)
+ }
- def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId, message)
+ def toTaskEndReason: TaskEndReason = FetchFailed(bmAddress, shuffleId, mapId, reduceId,
+ Utils.exceptionString(this))
}
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
index f03e8e4bf1b7e..7de2f9cbb2866 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/FileShuffleBlockManager.scala
@@ -27,6 +27,7 @@ import scala.collection.JavaConversions._
import org.apache.spark.{Logging, SparkConf, SparkEnv}
import org.apache.spark.executor.ShuffleWriteMetrics
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FileShuffleBlockManager.ShuffleFileGroup
import org.apache.spark.storage._
@@ -68,6 +69,8 @@ private[spark]
class FileShuffleBlockManager(conf: SparkConf)
extends ShuffleBlockManager with Logging {
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
private lazy val blockManager = SparkEnv.get.blockManager
// Turning off shuffle file consolidation causes all shuffle Blocks to get their own file.
@@ -182,13 +185,14 @@ class FileShuffleBlockManager(conf: SparkConf)
val segmentOpt = iter.next.getFileSegmentFor(blockId.mapId, blockId.reduceId)
if (segmentOpt.isDefined) {
val segment = segmentOpt.get
- return new FileSegmentManagedBuffer(segment.file, segment.offset, segment.length)
+ return new FileSegmentManagedBuffer(
+ transportConf, segment.file, segment.offset, segment.length)
}
}
throw new IllegalStateException("Failed to find shuffle block: " + blockId)
} else {
val file = blockManager.diskBlockManager.getFile(blockId)
- new FileSegmentManagedBuffer(file, 0, file.length)
+ new FileSegmentManagedBuffer(transportConf, file, 0, file.length)
}
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
index a48f0c9eceb5e..b292587d37028 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/IndexShuffleBlockManager.scala
@@ -22,8 +22,9 @@ import java.nio.ByteBuffer
import com.google.common.io.ByteStreams
-import org.apache.spark.SparkEnv
+import org.apache.spark.{SparkConf, SparkEnv}
import org.apache.spark.network.buffer.{FileSegmentManagedBuffer, ManagedBuffer}
+import org.apache.spark.network.netty.SparkTransportConf
import org.apache.spark.storage._
/**
@@ -38,10 +39,12 @@ import org.apache.spark.storage._
// Note: Changes to the format in this file should be kept in sync with
// org.apache.spark.network.shuffle.StandaloneShuffleBlockManager#getSortBasedShuffleBlockData().
private[spark]
-class IndexShuffleBlockManager extends ShuffleBlockManager {
+class IndexShuffleBlockManager(conf: SparkConf) extends ShuffleBlockManager {
private lazy val blockManager = SparkEnv.get.blockManager
+ private val transportConf = SparkTransportConf.fromSparkConf(conf)
+
/**
* Mapping to a single shuffleBlockId with reduce ID 0.
* */
@@ -109,6 +112,7 @@ class IndexShuffleBlockManager extends ShuffleBlockManager {
val offset = in.readLong()
val nextOffset = in.readLong()
new FileSegmentManagedBuffer(
+ transportConf,
getDataFile(blockId.shuffleId, blockId.mapId),
offset,
nextOffset - offset)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
index 801ae54086053..a44a8e1249256 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleManager.scala
@@ -20,8 +20,8 @@ package org.apache.spark.shuffle
import org.apache.spark.{TaskContext, ShuffleDependency}
/**
- * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on both the
- * driver and executors, based on the spark.shuffle.manager setting. The driver registers shuffles
+ * Pluggable interface for shuffle systems. A ShuffleManager is created in SparkEnv on the driver
+ * and on each executor, based on the spark.shuffle.manager setting. The driver registers shuffles
* with it, and executors (or tasks running locally in the driver) can ask to read and write data.
*
* NOTE: this will be instantiated by SparkEnv so its constructor can take a SparkConf and
diff --git a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
index ee91a368b76ea..3bcc7178a3d8b 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/ShuffleMemoryManager.scala
@@ -66,8 +66,9 @@ private[spark] class ShuffleMemoryManager(maxMemory: Long) extends Logging {
val curMem = threadMemory(threadId)
val freeMemory = maxMemory - threadMemory.values.sum
- // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads
- val maxToGrant = math.min(numBytes, (maxMemory / numActiveThreads) - curMem)
+ // How much we can grant this thread; don't let it grow to more than 1 / numActiveThreads;
+ // don't let it be negative
+ val maxToGrant = math.min(numBytes, math.max(0, (maxMemory / numActiveThreads) - curMem))
if (curMem < maxMemory / (2 * numActiveThreads)) {
// We want to let each thread get at least 1 / (2 * numActiveThreads) before blocking;
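The added `math.max(0, ...)` guards against a thread that already holds more than its 1/numActiveThreads share, which would otherwise make the computed grant negative. A small worked sketch of the clamped formula with made-up numbers:

object ShuffleGrantSketch {
  // Clamp the grant: never exceed the request, never exceed the thread's fair share,
  // and never go negative when the thread is already over its share.
  def maxToGrant(numBytes: Long, maxMemory: Long, numActiveThreads: Int, curMem: Long): Long =
    math.min(numBytes, math.max(0L, maxMemory / numActiveThreads - curMem))

  def main(args: Array[String]): Unit = {
    // Fair share is 1000 / 4 = 250. A thread already holding 300 is over its share:
    // the unclamped formula would yield 250 - 300 = -50; the clamp grants 0 instead.
    println(maxToGrant(numBytes = 100, maxMemory = 1000, numActiveThreads = 4, curMem = 300)) // 0
    println(maxToGrant(numBytes = 100, maxMemory = 1000, numActiveThreads = 4, curMem = 200)) // 50
  }
}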
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
index 0d5247f4176d4..e3e7434df45b0 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/BlockStoreShuffleFetcher.scala
@@ -25,7 +25,7 @@ import org.apache.spark._
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.FetchFailedException
import org.apache.spark.storage.{BlockId, BlockManagerId, ShuffleBlockFetcherIterator, ShuffleBlockId}
-import org.apache.spark.util.{CompletionIterator, Utils}
+import org.apache.spark.util.CompletionIterator
private[hash] object BlockStoreShuffleFetcher extends Logging {
def fetch[T](
@@ -64,8 +64,7 @@ private[hash] object BlockStoreShuffleFetcher extends Logging {
blockId match {
case ShuffleBlockId(shufId, mapId, _) =>
val address = statuses(mapId.toInt)._1
- throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId,
- Utils.exceptionString(e))
+ throw new FetchFailedException(address, shufId.toInt, mapId.toInt, reduceId, e)
case _ =>
throw new SparkException(
"Failed to get block " + blockId + ", which is not a shuffle block", e)
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
index 5baf45db45c17..de72148ccc7ac 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleReader.scala
@@ -45,9 +45,9 @@ private[spark] class HashShuffleReader[K, C](
} else {
new InterruptibleIterator(context, dep.aggregator.get.combineValuesByKey(iter, context))
}
- } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
+
// Convert the Product2s to pairs since this is what downstream RDDs currently expect
iter.asInstanceOf[Iterator[Product2[K, C]]].map(pair => (pair._1, pair._2))
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
index 183a30373b28c..755f17d6aa15a 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/hash/HashShuffleWriter.scala
@@ -56,9 +56,8 @@ private[spark] class HashShuffleWriter[K, V](
} else {
records
}
- } else if (dep.aggregator.isEmpty && dep.mapSideCombine) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
} else {
+ require(!dep.mapSideCombine, "Map-side combine without Aggregator specified!")
records
}
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
index b727438ae7e47..bda30a56d808e 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleManager.scala
@@ -25,7 +25,7 @@ import org.apache.spark.shuffle.hash.HashShuffleReader
private[spark] class SortShuffleManager(conf: SparkConf) extends ShuffleManager {
- private val indexShuffleBlockManager = new IndexShuffleBlockManager()
+ private val indexShuffleBlockManager = new IndexShuffleBlockManager(conf)
private val shuffleMapNumber = new ConcurrentHashMap[Int, Int]()
/**
diff --git a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
index d75f9d7311fad..27496c5a289cb 100644
--- a/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
+++ b/core/src/main/scala/org/apache/spark/shuffle/sort/SortShuffleWriter.scala
@@ -50,9 +50,7 @@ private[spark] class SortShuffleWriter[K, V, C](
/** Write a bunch of records to this task's output */
override def write(records: Iterator[_ <: Product2[K, V]]): Unit = {
if (dep.mapSideCombine) {
- if (!dep.aggregator.isDefined) {
- throw new IllegalStateException("Aggregator is empty for map-side combine")
- }
+ require(dep.aggregator.isDefined, "Map-side combine without Aggregator specified!")
sorter = new ExternalSorter[K, V, C](
dep.aggregator, Some(dep.partitioner), dep.keyOrdering, dep.serializer)
sorter.insertAll(records)
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
index 5f5dd0dc1c63f..d7b184f8a10e9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManager.scala
@@ -35,12 +35,12 @@ import org.apache.spark.io.CompressionCodec
import org.apache.spark.network._
import org.apache.spark.network.buffer.{ManagedBuffer, NioManagedBuffer}
import org.apache.spark.network.netty.{SparkTransportConf, NettyBlockTransferService}
-import org.apache.spark.network.shuffle.{ExecutorShuffleInfo, ExternalShuffleClient}
+import org.apache.spark.network.shuffle.ExternalShuffleClient
+import org.apache.spark.network.shuffle.protocol.ExecutorShuffleInfo
import org.apache.spark.network.util.{ConfigProvider, TransportConf}
import org.apache.spark.serializer.Serializer
import org.apache.spark.shuffle.ShuffleManager
import org.apache.spark.shuffle.hash.HashShuffleManager
-import org.apache.spark.shuffle.sort.SortShuffleManager
import org.apache.spark.util._
private[spark] sealed trait BlockValues
@@ -57,6 +57,12 @@ private[spark] class BlockResult(
inputMetrics.bytesRead = bytes
}
+/**
+ * Manager running on every node (driver and executors) which provides interfaces for putting and
+ * retrieving blocks both locally and remotely into various stores (memory, disk, and off-heap).
+ *
+ * Note that #initialize() must be called before the BlockManager is usable.
+ */
private[spark] class BlockManager(
executorId: String,
actorSystem: ActorSystem,
@@ -66,11 +72,11 @@ private[spark] class BlockManager(
val conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService)
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager,
+ numUsableCores: Int)
extends BlockDataManager with Logging {
- blockTransferService.init(this)
-
val diskBlockManager = new DiskBlockManager(this, conf)
private val blockInfo = new TimeStampedHashMap[BlockId, BlockInfo]
@@ -92,7 +98,12 @@ private[spark] class BlockManager(
private[spark]
val externalShuffleServiceEnabled = conf.getBoolean("spark.shuffle.service.enabled", false)
- private val externalShuffleServicePort = conf.getInt("spark.shuffle.service.port", 7337)
+
+ // Port used by the external shuffle service. In Yarn mode, this may already be
+ // set through the Hadoop configuration as the server is launched in the Yarn NM.
+ private val externalShuffleServicePort =
+ Utils.getSparkOrYarnConfig(conf, "spark.shuffle.service.port", "7337").toInt
+
// Check that we're not using external shuffle service with consolidated shuffle files.
if (externalShuffleServiceEnabled
&& conf.getBoolean("spark.shuffle.consolidateFiles", false)
@@ -102,22 +113,17 @@ private[spark] class BlockManager(
+ " switch to sort-based shuffle.")
}
- val blockManagerId = BlockManagerId(
- executorId, blockTransferService.hostName, blockTransferService.port)
+ var blockManagerId: BlockManagerId = _
// Address of the server that serves this executor's shuffle files. This is either an external
// service, or just our own Executor's BlockManager.
- private[spark] val shuffleServerId = if (externalShuffleServiceEnabled) {
- BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
- } else {
- blockManagerId
- }
+ private[spark] var shuffleServerId: BlockManagerId = _
// Client to read other executors' shuffle files. This is either an external service, or just the
// standard BlockTranserService to directly connect to other Executors.
private[spark] val shuffleClient = if (externalShuffleServiceEnabled) {
- val appId = conf.get("spark.app.id", "unknown-app-id")
- new ExternalShuffleClient(SparkTransportConf.fromSparkConf(conf), appId)
+ val transConf = SparkTransportConf.fromSparkConf(conf, numUsableCores)
+ new ExternalShuffleClient(transConf, securityManager, securityManager.isAuthenticationEnabled())
} else {
blockTransferService
}
@@ -150,8 +156,6 @@ private[spark] class BlockManager(
private val peerFetchLock = new Object
private var lastPeerFetchTime = 0L
- initialize()
-
/* The compression codec to use. Note that the "lazy" val is necessary because we want to delay
* the initialization of the compression codec until it is first used. The reason is that a Spark
* program could be using a user-defined codec in a third party jar, which is loaded in
@@ -170,16 +174,35 @@ private[spark] class BlockManager(
conf: SparkConf,
mapOutputTracker: MapOutputTracker,
shuffleManager: ShuffleManager,
- blockTransferService: BlockTransferService) = {
+ blockTransferService: BlockTransferService,
+ securityManager: SecurityManager,
+ numUsableCores: Int) = {
this(execId, actorSystem, master, serializer, BlockManager.getMaxMemory(conf),
- conf, mapOutputTracker, shuffleManager, blockTransferService)
+ conf, mapOutputTracker, shuffleManager, blockTransferService, securityManager, numUsableCores)
}
/**
- * Initialize the BlockManager. Register to the BlockManagerMaster, and start the
- * BlockManagerWorker actor. Additionally registers with a local shuffle service if configured.
+ * Initializes the BlockManager with the given appId. This is not performed in the constructor as
+ * the appId may not be known at BlockManager instantiation time (in particular for the driver,
+ * where it is only learned after registration with the TaskScheduler).
+ *
+ * This method initializes the BlockTransferService and ShuffleClient, registers with the
+ * BlockManagerMaster, starts the BlockManagerWorker actor, and registers with a local shuffle
+ * service if configured.
*/
- private def initialize(): Unit = {
+ def initialize(appId: String): Unit = {
+ blockTransferService.init(this)
+ shuffleClient.init(appId)
+
+ blockManagerId = BlockManagerId(
+ executorId, blockTransferService.hostName, blockTransferService.port)
+
+ shuffleServerId = if (externalShuffleServiceEnabled) {
+ BlockManagerId(executorId, blockTransferService.hostName, externalShuffleServicePort)
+ } else {
+ blockManagerId
+ }
+
master.registerBlockManager(blockManagerId, maxMemory, slaveActor)
// Register Executors' configuration with the local shuffle service, if one should exist.
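After this change, constructing a BlockManager no longer registers it anywhere; callers must invoke `initialize(appId)` once the application ID is known (on the driver, only after registering with the TaskScheduler). A toy sketch of that two-phase initialization lifecycle (names and wiring are illustrative, not Spark's real constructor):

object TwoPhaseInitSketch {
  class BlockManagerLike(executorId: String) {
    // Identity is unknown at construction time; it is assigned during initialize().
    private var managerId: Option[String] = None

    def initialize(appId: String): Unit = {
      require(managerId.isEmpty, "already initialized")
      managerId = Some(s"$appId/$executorId")
      println(s"registered ${managerId.get} with the master")
    }

    def put(block: String): Unit = {
      require(managerId.nonEmpty, "initialize(appId) must be called before use")
      println(s"${managerId.get} stored $block")
    }
  }

  def main(args: Array[String]): Unit = {
    val bm = new BlockManagerLike("executor-1")
    // bm.put("rdd_0_0")            // would fail: not initialized yet
    bm.initialize("app-20141215")   // appId becomes known only after scheduler registration
    bm.put("rdd_0_0")
  }
}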
@@ -206,7 +229,6 @@ private[spark] class BlockManager(
return
} catch {
case e: Exception if i < MAX_ATTEMPTS =>
- val attemptsRemaining =
logError(s"Failed to connect to external shuffle server, will retry ${MAX_ATTEMPTS - i}}"
+ s" more times after waiting $SLEEP_TIME_SECS seconds...", e)
Thread.sleep(SLEEP_TIME_SECS * 1000)
@@ -920,7 +942,7 @@ private[spark] class BlockManager(
data.rewind()
logTrace(s"Trying to replicate $blockId of ${data.limit()} bytes to $peer")
blockTransferService.uploadBlockSync(
- peer.host, peer.port, blockId, new NioManagedBuffer(data), tLevel)
+ peer.host, peer.port, peer.executorId, blockId, new NioManagedBuffer(data), tLevel)
logTrace(s"Replicated $blockId of ${data.limit()} bytes to $peer in %s ms"
.format(System.currentTimeMillis - onePeerStartTime))
peersReplicatedTo += peer
@@ -992,8 +1014,10 @@ private[spark] class BlockManager(
// If we get here, the block write failed.
logWarning(s"Block $blockId was marked as failure. Nothing to drop")
return None
+ } else if (blockInfo.get(blockId).isEmpty) {
+ logWarning(s"Block $blockId was already dropped.")
+ return None
}
-
var blockIsUpdated = false
val level = info.level
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
index d08e1419e3e41..b63c7f191155c 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMaster.scala
@@ -88,6 +88,10 @@ class BlockManagerMaster(
askDriverWithReply[Seq[BlockManagerId]](GetPeers(blockManagerId))
}
+ def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ askDriverWithReply[Option[(String, Int)]](GetActorSystemHostPortForExecutor(executorId))
+ }
+
/**
* Remove a block from the slaves that have it. This can only be used to remove
* blocks that the driver knows about.
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
index 5e375a2553979..685b2e11440fb 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMasterActor.scala
@@ -86,6 +86,9 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
case GetPeers(blockManagerId) =>
sender ! getPeers(blockManagerId)
+ case GetActorSystemHostPortForExecutor(executorId) =>
+ sender ! getActorSystemHostPortForExecutor(executorId)
+
case GetMemoryStatus =>
sender ! memoryStatus
@@ -412,6 +415,21 @@ class BlockManagerMasterActor(val isLocal: Boolean, conf: SparkConf, listenerBus
Seq.empty
}
}
+
+ /**
+ * Returns the hostname and port of an executor's actor system, based on the Akka address of its
+ * BlockManagerSlaveActor.
+ */
+ private def getActorSystemHostPortForExecutor(executorId: String): Option[(String, Int)] = {
+ for (
+ blockManagerId <- blockManagerIdByExecutor.get(executorId);
+ info <- blockManagerInfo.get(blockManagerId);
+ host <- info.slaveActor.path.address.host;
+ port <- info.slaveActor.path.address.port
+ ) yield {
+ (host, port)
+ }
+ }
}
@DeveloperApi
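The new lookup chains several `Option`s in a for-comprehension, yielding a host/port pair only if the executor, its block manager info, and both address fields are all present. A minimal sketch of that shape with plain maps instead of Akka actor paths (field names invented for illustration):

object OptionChainSketch {
  case class SlaveAddress(host: Option[String], port: Option[Int])

  val blockManagerIdByExecutor = Map("exec-1" -> "bm-1")
  val addressByBlockManager = Map("bm-1" -> SlaveAddress(Some("worker-7"), Some(45123)))

  // None at any step (unknown executor, missing info, unset host/port) yields None overall.
  def hostPortForExecutor(executorId: String): Option[(String, Int)] =
    for {
      bmId <- blockManagerIdByExecutor.get(executorId)
      addr <- addressByBlockManager.get(bmId)
      host <- addr.host
      port <- addr.port
    } yield (host, port)

  def main(args: Array[String]): Unit = {
    println(hostPortForExecutor("exec-1")) // Some((worker-7,45123))
    println(hostPortForExecutor("exec-2")) // None
  }
}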
diff --git a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
index 291ddfcc113ac..3f32099d08cc9 100644
--- a/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
+++ b/core/src/main/scala/org/apache/spark/storage/BlockManagerMessages.scala
@@ -92,6 +92,8 @@ private[spark] object BlockManagerMessages {
case class GetPeers(blockManagerId: BlockManagerId) extends ToBlockManagerMaster
+ case class GetActorSystemHostPortForExecutor(executorId: String) extends ToBlockManagerMaster
+
case class RemoveExecutor(execId: String) extends ToBlockManagerMaster
case object StopBlockManagerMaster extends ToBlockManagerMaster
diff --git a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
index 58fba54710510..ffaac4b17657c 100644
--- a/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/DiskBlockManager.scala
@@ -17,9 +17,8 @@
package org.apache.spark.storage
-import java.io.File
-import java.text.SimpleDateFormat
-import java.util.{Date, Random, UUID}
+import java.util.UUID
+import java.io.{IOException, File}
import org.apache.spark.{SparkConf, Logging}
import org.apache.spark.executor.ExecutorExitCode
@@ -37,7 +36,6 @@ import org.apache.spark.util.Utils
private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkConf)
extends Logging {
- private val MAX_DIR_CREATION_ATTEMPTS: Int = 10
private[spark]
val subDirsPerLocalDir = blockManager.conf.getInt("spark.diskStore.subDirectories", 64)
@@ -121,33 +119,15 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
}
private def createLocalDirs(conf: SparkConf): Array[File] = {
- val dateFormat = new SimpleDateFormat("yyyyMMddHHmmss")
Utils.getOrCreateLocalRootDirs(conf).flatMap { rootDir =>
- var foundLocalDir = false
- var localDir: File = null
- var localDirId: String = null
- var tries = 0
- val rand = new Random()
- while (!foundLocalDir && tries < MAX_DIR_CREATION_ATTEMPTS) {
- tries += 1
- try {
- localDirId = "%s-%04x".format(dateFormat.format(new Date), rand.nextInt(65536))
- localDir = new File(rootDir, s"spark-local-$localDirId")
- if (!localDir.exists) {
- foundLocalDir = localDir.mkdirs()
- }
- } catch {
- case e: Exception =>
- logWarning(s"Attempt $tries to create local dir $localDir failed", e)
- }
- }
- if (!foundLocalDir) {
- logError(s"Failed $MAX_DIR_CREATION_ATTEMPTS attempts to create local dir in $rootDir." +
- " Ignoring this directory.")
- None
- } else {
+ try {
+ val localDir = Utils.createDirectory(rootDir, "blockmgr")
logInfo(s"Created local directory at $localDir")
Some(localDir)
+ } catch {
+ case e: IOException =>
+ logError(s"Failed to create local dir in $rootDir. Ignoring this directory.", e)
+ None
}
}
}
@@ -164,7 +144,7 @@ private[spark] class DiskBlockManager(blockManager: BlockManager, conf: SparkCon
/** Cleanup local dirs and stop shuffle sender. */
private[spark] def stop() {
// Only perform cleanup if an external service is not serving our shuffle files.
- if (!blockManager.externalShuffleServiceEnabled) {
+ if (!blockManager.externalShuffleServiceEnabled || blockManager.blockManagerId.isDriver) {
localDirs.foreach { localDir =>
if (localDir.isDirectory() && localDir.exists()) {
try {
diff --git a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
index 1e579187e4193..2499c11a65b0e 100644
--- a/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
+++ b/core/src/main/scala/org/apache/spark/storage/ShuffleBlockFetcherIterator.scala
@@ -17,6 +17,7 @@
package org.apache.spark.storage
+import java.io.{InputStream, IOException}
import java.util.concurrent.LinkedBlockingQueue
import scala.collection.mutable.{ArrayBuffer, HashSet, Queue}
@@ -92,7 +93,7 @@ final class ShuffleBlockFetcherIterator(
* Current [[FetchResult]] being processed. We track this so we can release the current buffer
* in case of a runtime exception when processing the current buffer.
*/
- private[this] var currentResult: FetchResult = null
+ @volatile private[this] var currentResult: FetchResult = null
/**
* Queue of fetch requests to issue; we'll pull requests off this gradually to make sure that
@@ -265,7 +266,7 @@ final class ShuffleBlockFetcherIterator(
// Get Local Blocks
fetchLocalBlocks()
- logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime) + " ms")
+ logDebug("Got local blocks in " + Utils.getUsedTimeMs(startTime))
}
override def hasNext: Boolean = numBlocksProcessed < numBlocksToFetch
@@ -289,17 +290,22 @@ final class ShuffleBlockFetcherIterator(
}
val iteratorTry: Try[Iterator[Any]] = result match {
- case FailureFetchResult(_, e) => Failure(e)
- case SuccessFetchResult(blockId, _, buf) => {
- val is = blockManager.wrapForCompression(blockId, buf.createInputStream())
- val iter = serializer.newInstance().deserializeStream(is).asIterator
- Success(CompletionIterator[Any, Iterator[Any]](iter, {
- // Once the iterator is exhausted, release the buffer and set currentResult to null
- // so we don't release it again in cleanup.
- currentResult = null
- buf.release()
- }))
- }
+ case FailureFetchResult(_, e) =>
+ Failure(e)
+ case SuccessFetchResult(blockId, _, buf) =>
+ // There is a chance that createInputStream can fail (e.g. fetching a local file that does
+ // not exist, SPARK-4085). In that case, we should propagate the right exception so
+ // the scheduler gets a FetchFailedException.
+ Try(buf.createInputStream()).map { is0 =>
+ val is = blockManager.wrapForCompression(blockId, is0)
+ val iter = serializer.newInstance().deserializeStream(is).asIterator
+ CompletionIterator[Any, Iterator[Any]](iter, {
+ // Once the iterator is exhausted, release the buffer and set currentResult to null
+ // so we don't release it again in cleanup.
+ currentResult = null
+ buf.release()
+ })
+ }
}
(result.blockId, iteratorTry)
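Wrapping `createInputStream()` itself in a `Try` means that a failure while opening the fetched buffer (e.g. a missing local file, SPARK-4085) surfaces as a `Failure` and ultimately as a fetch failure, instead of escaping the iterator. A small standalone sketch of that `Try(...).map` shape with a toy buffer type (not the real ManagedBuffer API):

import java.io.{ByteArrayInputStream, IOException, InputStream}
import scala.util.Try

object TryFetchSketch {
  // Toy stand-in for a fetched buffer whose createInputStream() may fail.
  class ToyBuffer(data: Option[Array[Byte]]) {
    def createInputStream(): InputStream = data match {
      case Some(bytes) => new ByteArrayInputStream(bytes)
      case None => throw new IOException("underlying file is gone")
    }
  }

  // Opening the stream is inside the Try, so an open failure becomes a Failure value.
  def firstByte(buf: ToyBuffer): Try[Int] =
    Try(buf.createInputStream()).map { is =>
      try is.read() finally is.close()
    }

  def main(args: Array[String]): Unit = {
    println(firstByte(new ToyBuffer(Some(Array[Byte](42))))) // Success(42)
    println(firstByte(new ToyBuffer(None)))                  // Failure(java.io.IOException: ...)
  }
}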
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
index 6908a59a79e60..af873034215a9 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonBlockManager.scala
@@ -148,6 +148,7 @@ private[spark] class TachyonBlockManager(
logError("Exception while deleting tachyon spark dir: " + tachyonDir, e)
}
}
+ client.close()
}
})
}
diff --git a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
index 6dbad5ff0518e..233d1e2b7c616 100644
--- a/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
+++ b/core/src/main/scala/org/apache/spark/storage/TachyonStore.scala
@@ -116,6 +116,8 @@ private[spark] class TachyonStore(
case ioe: IOException =>
logWarning(s"Failed to fetch the block $blockId from Tachyon", ioe)
None
+ } finally {
+ is.close()
}
}
diff --git a/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
new file mode 100644
index 0000000000000..27ba9e18237b5
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/ConsoleProgressBar.scala
@@ -0,0 +1,124 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui
+
+import java.util.{Timer, TimerTask}
+
+import org.apache.spark._
+
+/**
+ * ConsoleProgressBar shows the progress of stages in the next line of the console. It polls the
+ * status of active stages from `sc.statusTracker` periodically; the progress bar is shown once a
+ * stage has run for at least 500 ms. If multiple stages run at the same time, their statuses are
+ * combined and shown on one line.
+ */
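+// For example, a single active stage with id 0 and 10 tasks, of which 2 have finished and
+// 3 are running, renders roughly as follows (the exact bar width depends on the terminal):
+//   [Stage 0:=====>              (2 + 3) / 10]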
+private[spark] class ConsoleProgressBar(sc: SparkContext) extends Logging {
+
+  // Carriage return
+ val CR = '\r'
+ // Update period of progress bar, in milliseconds
+ val UPDATE_PERIOD = 200L
+  // Delay before showing the progress bar, in milliseconds
+ val FIRST_DELAY = 500L
+
+  // The width of the terminal, in characters
+ val TerminalWidth = if (!sys.env.getOrElse("COLUMNS", "").isEmpty) {
+ sys.env.get("COLUMNS").get.toInt
+ } else {
+ 80
+ }
+
+ var lastFinishTime = 0L
+ var lastUpdateTime = 0L
+ var lastProgressBar = ""
+
+ // Schedule a refresh thread to run periodically
+ private val timer = new Timer("refresh progress", true)
+ timer.schedule(new TimerTask{
+ override def run() {
+ refresh()
+ }
+ }, FIRST_DELAY, UPDATE_PERIOD)
+
+ /**
+   * Try to refresh the progress bar on every update cycle.
+ */
+ private def refresh(): Unit = synchronized {
+ val now = System.currentTimeMillis()
+ if (now - lastFinishTime < FIRST_DELAY) {
+ return
+ }
+ val stageIds = sc.statusTracker.getActiveStageIds()
+ val stages = stageIds.map(sc.statusTracker.getStageInfo).flatten.filter(_.numTasks() > 1)
+ .filter(now - _.submissionTime() > FIRST_DELAY).sortBy(_.stageId())
+ if (stages.size > 0) {
+      show(now, stages.take(3)) // display at most 3 stages at the same time
+ }
+ }
+
+ /**
+   * Show the progress bar in the console. The progress bar is displayed on the line following
+   * your last output and keeps overwriting itself to stay on a single line. Any log output pushes
+   * the bar down, so the bar is redrawn on the next line without overwriting the logs.
+ */
+ private def show(now: Long, stages: Seq[SparkStageInfo]) {
+ val width = TerminalWidth / stages.size
+ val bar = stages.map { s =>
+ val total = s.numTasks()
+ val header = s"[Stage ${s.stageId()}:"
+ val tailer = s"(${s.numCompletedTasks()} + ${s.numActiveTasks()}) / $total]"
+ val w = width - header.size - tailer.size
+ val bar = if (w > 0) {
+ val percent = w * s.numCompletedTasks() / total
+ (0 until w).map { i =>
+ if (i < percent) "=" else if (i == percent) ">" else " "
+ }.mkString("")
+ } else {
+ ""
+ }
+ header + bar + tailer
+ }.mkString("")
+
+    // Only refresh if the bar has changed, or at least once per minute (otherwise an idle
+    // SSH connection may be closed)
+ if (bar != lastProgressBar || now - lastUpdateTime > 60 * 1000L) {
+ System.err.print(CR + bar)
+ lastUpdateTime = now
+ }
+ lastProgressBar = bar
+ }
+
+ /**
+   * Clear the progress bar if it is currently shown.
+ */
+ private def clear() {
+ if (!lastProgressBar.isEmpty) {
+ System.err.printf(CR + " " * TerminalWidth + CR)
+ lastProgressBar = ""
+ }
+ }
+
+ /**
+   * Mark all the stages as finished and clear the progress bar if it is shown, so that the
+   * progress output does not interleave with the output of jobs.
+ */
+ def finishAll(): Unit = synchronized {
+ clear()
+ lastFinishTime = System.currentTimeMillis()
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
index 2a27d49d2de05..88fed833f922d 100644
--- a/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/JettyUtils.scala
@@ -201,7 +201,7 @@ private[spark] object JettyUtils extends Logging {
}
}
- val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, serverName)
+ val (server, boundPort) = Utils.startServiceOnPort[Server](port, connect, conf, serverName)
ServerInfo(server, boundPort, collection)
}
diff --git a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
index 049938f827291..0c24ad2760e08 100644
--- a/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
+++ b/core/src/main/scala/org/apache/spark/ui/SparkUI.scala
@@ -23,7 +23,7 @@ import org.apache.spark.storage.StorageStatusListener
import org.apache.spark.ui.JettyUtils._
import org.apache.spark.ui.env.{EnvironmentListener, EnvironmentTab}
import org.apache.spark.ui.exec.{ExecutorsListener, ExecutorsTab}
-import org.apache.spark.ui.jobs.{JobProgressListener, JobProgressTab}
+import org.apache.spark.ui.jobs.{JobsTab, JobProgressListener, StagesTab}
import org.apache.spark.ui.storage.{StorageListener, StorageTab}
/**
@@ -43,19 +43,20 @@ private[spark] class SparkUI private (
extends WebUI(securityManager, SparkUI.getUIPort(conf), conf, basePath, "SparkUI")
with Logging {
+ val killEnabled = sc.map(_.conf.getBoolean("spark.ui.killEnabled", true)).getOrElse(false)
+
/** Initialize all components of the server. */
def initialize() {
- val jobProgressTab = new JobProgressTab(this)
- attachTab(jobProgressTab)
+ attachTab(new JobsTab(this))
+ val stagesTab = new StagesTab(this)
+ attachTab(stagesTab)
attachTab(new StorageTab(this))
attachTab(new EnvironmentTab(this))
attachTab(new ExecutorsTab(this))
attachHandler(createStaticHandler(SparkUI.STATIC_RESOURCE_DIR, "/static"))
- attachHandler(createRedirectHandler("/", "/stages", basePath = basePath))
+ attachHandler(createRedirectHandler("/", "/jobs", basePath = basePath))
attachHandler(
- createRedirectHandler("/stages/stage/kill", "/stages", jobProgressTab.handleKillRequest))
- // If the UI is live, then serve
- sc.foreach { _.env.metricsSystem.getServletHandlers.foreach(attachHandler) }
+ createRedirectHandler("/stages/stage/kill", "/stages", stagesTab.handleKillRequest))
}
initialize()
diff --git a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
index f02904df31fcf..6f446c5a95a0a 100644
--- a/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
+++ b/core/src/main/scala/org/apache/spark/ui/ToolTips.scala
@@ -24,8 +24,13 @@ private[spark] object ToolTips {
scheduler delay is large, consider decreasing the size of tasks or decreasing the size
of task results."""
+ val TASK_DESERIALIZATION_TIME =
+ """Time spent deserializating the task closure on the executor."""
+
val INPUT = "Bytes read from Hadoop or from Spark storage."
+ val OUTPUT = "Bytes written to Hadoop."
+
val SHUFFLE_WRITE = "Bytes written to disk in order to be read by a shuffle in a future stage."
val SHUFFLE_READ =
diff --git a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
index 3312671b6f885..b5022fe853c49 100644
--- a/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
+++ b/core/src/main/scala/org/apache/spark/ui/UIUtils.scala
@@ -26,7 +26,8 @@ import org.apache.spark.Logging
/** Utility functions for generating XML pages with spark content. */
private[spark] object UIUtils extends Logging {
- val TABLE_CLASS = "table table-bordered table-striped-custom table-condensed sortable"
+ val TABLE_CLASS_NOT_STRIPED = "table table-bordered table-condensed sortable"
+ val TABLE_CLASS_STRIPED = TABLE_CLASS_NOT_STRIPED + " table-striped"
// SimpleDateFormat is not thread-safe. Don't expose it to avoid improper use.
private val dateFormat = new ThreadLocal[SimpleDateFormat]() {
@@ -169,15 +170,21 @@ private[spark] object UIUtils extends Logging {
title: String,
content: => Seq[Node],
activeTab: SparkUITab,
- refreshInterval: Option[Int] = None): Seq[Node] = {
+ refreshInterval: Option[Int] = None,
+ helpText: Option[String] = None): Seq[Node] = {
val appName = activeTab.appName
val shortAppName = if (appName.length < 36) appName else appName.take(32) + "..."
val header = activeTab.headerTabs.map { tab =>
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
new file mode 100644
index 0000000000000..c82730f524eb7
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorThreadDumpPage.scala
@@ -0,0 +1,86 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.exec
+
+import java.net.URLDecoder
+import javax.servlet.http.HttpServletRequest
+
+import scala.util.Try
+import scala.xml.{Text, Node}
+
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+private[ui] class ExecutorThreadDumpPage(parent: ExecutorsTab) extends WebUIPage("threadDump") {
+
+ private val sc = parent.sc
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ val executorId = Option(request.getParameter("executorId")).map {
+ executorId =>
+        // Due to YARN-2844, "<driver>" in the URL will be encoded to "%25253Cdriver%25253E" when
+        // running in yarn-cluster mode. `request.getParameter("executorId")` will return
+        // "%253Cdriver%253E". Therefore we need to keep decoding it until we get the real id.
+ var id = executorId
+ var decodedId = URLDecoder.decode(id, "UTF-8")
+ while (id != decodedId) {
+ id = decodedId
+ decodedId = URLDecoder.decode(id, "UTF-8")
+ }
+ id
+ }.getOrElse {
+ return Text(s"Missing executorId parameter")
+ }
+ val time = System.currentTimeMillis()
+ val maybeThreadDump = sc.get.getExecutorThreadDump(executorId)
+
+ val content = maybeThreadDump.map { threadDump =>
+ val dumpRows = threadDump.map { thread =>
+
+ } else {
+ Seq.empty
+ }
+ }
}
diff --git a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
index 9e0e71a51a408..dd1c2b78c4094 100644
--- a/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
+++ b/core/src/main/scala/org/apache/spark/ui/exec/ExecutorsTab.scala
@@ -27,8 +27,14 @@ import org.apache.spark.ui.{SparkUI, SparkUITab}
private[ui] class ExecutorsTab(parent: SparkUI) extends SparkUITab(parent, "executors") {
val listener = parent.executorsListener
+ val sc = parent.sc
+ val threadDumpEnabled =
+ sc.isDefined && parent.conf.getBoolean("spark.ui.threadDumpsEnabled", true)
- attachPage(new ExecutorsPage(this))
+ attachPage(new ExecutorsPage(this, threadDumpEnabled))
+ if (threadDumpEnabled) {
+ attachPage(new ExecutorThreadDumpPage(this))
+ }
}
/**
@@ -42,6 +48,7 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
val executorToTasksFailed = HashMap[String, Int]()
val executorToDuration = HashMap[String, Long]()
val executorToInputBytes = HashMap[String, Long]()
+ val executorToOutputBytes = HashMap[String, Long]()
val executorToShuffleRead = HashMap[String, Long]()
val executorToShuffleWrite = HashMap[String, Long]()
@@ -72,6 +79,10 @@ class ExecutorsListener(storageStatusListener: StorageStatusListener) extends Sp
executorToInputBytes(eid) =
executorToInputBytes.getOrElse(eid, 0L) + inputMetrics.bytesRead
}
+ metrics.outputMetrics.foreach { outputMetrics =>
+ executorToOutputBytes(eid) =
+ executorToOutputBytes.getOrElse(eid, 0L) + outputMetrics.bytesWritten
+ }
metrics.shuffleReadMetrics.foreach { shuffleRead =>
executorToShuffleRead(eid) =
executorToShuffleRead.getOrElse(eid, 0L) + shuffleRead.remoteBytesRead
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
new file mode 100644
index 0000000000000..ea2d187a0e8e4
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllJobsPage.scala
@@ -0,0 +1,151 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.xml.{Node, NodeSeq}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.JobExecutionStatus
+import org.apache.spark.ui.{WebUIPage, UIUtils}
+import org.apache.spark.ui.jobs.UIData.JobUIData
+
+/** Page showing list of all ongoing and recently finished jobs */
+private[ui] class AllJobsPage(parent: JobsTab) extends WebUIPage("") {
+ private val startTime: Option[Long] = parent.sc.map(_.startTime)
+ private val listener = parent.listener
+
+ private def jobsTable(jobs: Seq[JobUIData]): Seq[Node] = {
+ val someJobHasJobGroup = jobs.exists(_.jobGroup.isDefined)
+
+ val columns: Seq[Node] = {
+      <th>{if (someJobHasJobGroup) "Job Id (Job Group)" else "Job Id"}</th>
+      <th>Description</th>
+      <th>Submitted</th>
+      <th>Duration</th>
+      <th>Stages: Succeeded/Total</th>
+      <th>Tasks (for all stages): Succeeded/Total</th>
+ }
+
+ def makeRow(job: JobUIData): Seq[Node] = {
+ val lastStageInfo = listener.stageIdToInfo.get(job.stageIds.max)
+ val lastStageData = lastStageInfo.flatMap { s =>
+ listener.stageIdToData.get((s.stageId, s.attemptId))
+ }
+ val isComplete = job.status == JobExecutionStatus.SUCCEEDED
+ val lastStageName = lastStageInfo.map(_.name).getOrElse("(Unknown Stage Name)")
+ val lastStageDescription = lastStageData.flatMap(_.description).getOrElse("")
+ val duration: Option[Long] = {
+ job.startTime.map { start =>
+ val end = job.endTime.getOrElse(System.currentTimeMillis())
+ end - start
+ }
+ }
+ val formattedDuration = duration.map(d => UIUtils.formatDuration(d)).getOrElse("Unknown")
+ val formattedSubmissionTime = job.startTime.map(UIUtils.formatDate).getOrElse("Unknown")
+ val detailUrl =
+ "%s/jobs/job?id=%s".format(UIUtils.prependBaseUri(parent.basePath), job.jobId)
+
++ failedJobsTable
+
+    val helpText = """A job is triggered by an action, like "count()" or "saveAsTextFile()".""" +
+ " Click on a job's title to see information about the stages of tasks associated with" +
+ " the job."
+
+ UIUtils.headerSparkPage("Spark Jobs", content, parent, helpText = Some(helpText))
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
similarity index 80%
rename from core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
rename to core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
index 6e718eecdd52a..b0f8ca2ab0d3f 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/AllStagesPage.scala
@@ -25,7 +25,7 @@ import org.apache.spark.scheduler.Schedulable
import org.apache.spark.ui.{WebUIPage, UIUtils}
/** Page showing list of all ongoing and recently finished stages and pools */
-private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("") {
+private[ui] class AllStagesPage(parent: StagesTab) extends WebUIPage("") {
private val sc = parent.sc
private val listener = parent.listener
private def isFairScheduler = parent.isFairScheduler
@@ -34,16 +34,21 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
listener.synchronized {
val activeStages = listener.activeStages.values.toSeq
val completedStages = listener.completedStages.reverse.toSeq
+ val numCompletedStages = listener.numCompletedStages
val failedStages = listener.failedStages.reverse.toSeq
+ val numFailedStages = listener.numFailedStages
val now = System.currentTimeMillis
val activeStagesTable =
new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
- parent, parent.killEnabled)
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
val completedStagesTable =
- new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent)
+ new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
val failedStagesTable =
- new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent)
+ new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler)
// For now, pool information is only accessible in live UIs
val pools = sc.map(_.getAllPools).getOrElse(Seq.empty[Schedulable])
@@ -69,11 +74,11 @@ private[ui] class JobProgressPage(parent: JobProgressTab) extends WebUIPage("")
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
new file mode 100644
index 0000000000000..77d36209c6048
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobPage.scala
@@ -0,0 +1,177 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import scala.collection.mutable
+import scala.xml.{NodeSeq, Node}
+
+import javax.servlet.http.HttpServletRequest
+
+import org.apache.spark.JobExecutionStatus
+import org.apache.spark.scheduler.StageInfo
+import org.apache.spark.ui.{UIUtils, WebUIPage}
+
+/** Page showing statistics and stage list for a given job */
+private[ui] class JobPage(parent: JobsTab) extends WebUIPage("job") {
+ private val listener = parent.listener
+
+ def render(request: HttpServletRequest): Seq[Node] = {
+ listener.synchronized {
+ val jobId = request.getParameter("id").toInt
+ val jobDataOption = listener.jobIdToData.get(jobId)
+ if (jobDataOption.isEmpty) {
+ val content =
+          <div>
+            <p>No information to display for job {jobId}</p>
+          </div>
+ return UIUtils.headerSparkPage(
+ s"Details for Job $jobId", content, parent)
+ }
+ val jobData = jobDataOption.get
+ val isComplete = jobData.status != JobExecutionStatus.RUNNING
+ val stages = jobData.stageIds.map { stageId =>
+ // This could be empty if the JobProgressListener hasn't received information about the
+ // stage or if the stage information has been garbage collected
+ listener.stageIdToInfo.getOrElse(stageId,
+ new StageInfo(stageId, 0, "Unknown", 0, Seq.empty, "Unknown"))
+ }
+
+ val activeStages = mutable.Buffer[StageInfo]()
+ val completedStages = mutable.Buffer[StageInfo]()
+ // If the job is completed, then any pending stages are displayed as "skipped":
+ val pendingOrSkippedStages = mutable.Buffer[StageInfo]()
+ val failedStages = mutable.Buffer[StageInfo]()
+ for (stage <- stages) {
+ if (stage.submissionTime.isEmpty) {
+ pendingOrSkippedStages += stage
+ } else if (stage.completionTime.isDefined) {
+ if (stage.failureReason.isDefined) {
+ failedStages += stage
+ } else {
+ completedStages += stage
+ }
+ } else {
+ activeStages += stage
+ }
+ }
+
+ val activeStagesTable =
+ new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
+ val pendingOrSkippedStagesTable =
+ new StageTableBase(pendingOrSkippedStages.sortBy(_.stageId).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = false)
+ val completedStagesTable =
+ new StageTableBase(completedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler, killEnabled = false)
+ val failedStagesTable =
+ new FailedStageTable(failedStages.sortBy(_.submissionTime).reverse, parent.basePath,
+ parent.listener, isFairScheduler = parent.isFairScheduler)
+
+ val shouldShowActiveStages = activeStages.nonEmpty
+ val shouldShowPendingStages = !isComplete && pendingOrSkippedStages.nonEmpty
+ val shouldShowCompletedStages = completedStages.nonEmpty
+ val shouldShowSkippedStages = isComplete && pendingOrSkippedStages.nonEmpty
+ val shouldShowFailedStages = failedStages.nonEmpty
+
+ val summary: NodeSeq =
+
++
+ failedStagesTable.toNodeSeq
+ }
+ UIUtils.headerSparkPage(s"Details for Job $jobId", content, parent)
+ }
+ }
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
index b5207360510dd..72935beb3a34a 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobProgressListener.scala
@@ -17,7 +17,7 @@
package org.apache.spark.ui.jobs
-import scala.collection.mutable.{HashMap, ListBuffer}
+import scala.collection.mutable.{HashMap, HashSet, ListBuffer}
import org.apache.spark._
import org.apache.spark.annotation.DeveloperApi
@@ -40,41 +40,145 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
import JobProgressListener._
+ // Define a handful of type aliases so that data structures' types can serve as documentation.
+ // These type aliases are public because they're used in the types of public fields:
+
type JobId = Int
type StageId = Int
type StageAttemptId = Int
+ type PoolName = String
+ type ExecutorId = String
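+  // For example, `stageIdToActiveJobIds` below has type `HashMap[StageId, HashSet[JobId]]`,
+  // which reads more clearly than `HashMap[Int, HashSet[Int]]`.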
- // How many stages to remember
- val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
- // How many jobs to remember
- val retailedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS)
-
+ // Jobs:
val activeJobs = new HashMap[JobId, JobUIData]
val completedJobs = ListBuffer[JobUIData]()
val failedJobs = ListBuffer[JobUIData]()
val jobIdToData = new HashMap[JobId, JobUIData]
+ // Stages:
val activeStages = new HashMap[StageId, StageInfo]
val completedStages = ListBuffer[StageInfo]()
+ val skippedStages = ListBuffer[StageInfo]()
val failedStages = ListBuffer[StageInfo]()
val stageIdToData = new HashMap[(StageId, StageAttemptId), StageUIData]
val stageIdToInfo = new HashMap[StageId, StageInfo]
+ val stageIdToActiveJobIds = new HashMap[StageId, HashSet[JobId]]
+ val poolToActiveStages = HashMap[PoolName, HashMap[StageId, StageInfo]]()
+  // Total number of completed and failed stages that have ever been run. These may be greater
+  // than `completedStages.size` and `failedStages.size` if more stages or jobs have been run
+  // than JobProgressListener's retention limits allow.
+ var numCompletedStages = 0
+ var numFailedStages = 0
+
+ // Misc:
+ val executorIdToBlockManagerId = HashMap[ExecutorId, BlockManagerId]()
+ def blockManagerIds = executorIdToBlockManagerId.values.toSeq
- // Map from pool name to a hash map (map from stage id to StageInfo).
- val poolToActiveStages = HashMap[String, HashMap[Int, StageInfo]]()
+ var schedulingMode: Option[SchedulingMode] = None
- val executorIdToBlockManagerId = HashMap[String, BlockManagerId]()
+ // To limit the total memory usage of JobProgressListener, we only track information for a fixed
+ // number of non-active jobs and stages (there is no limit for active jobs and stages):
- var schedulingMode: Option[SchedulingMode] = None
+ val retainedStages = conf.getInt("spark.ui.retainedStages", DEFAULT_RETAINED_STAGES)
+ val retainedJobs = conf.getInt("spark.ui.retainedJobs", DEFAULT_RETAINED_JOBS)
+
+ // We can test for memory leaks by ensuring that collections that track non-active jobs and
+ // stages do not grow without bound and that collections for active jobs/stages eventually become
+ // empty once Spark is idle. Let's partition our collections into ones that should be empty
+  // once Spark is idle and ones that should have a hard- or soft-limited size.
+ // These methods are used by unit tests, but they're defined here so that people don't forget to
+ // update the tests when adding new collections. Some collections have multiple levels of
+ // nesting, etc, so this lets us customize our notion of "size" for each structure:
+
+ // These collections should all be empty once Spark is idle (no active stages / jobs):
+ private[spark] def getSizesOfActiveStateTrackingCollections: Map[String, Int] = {
+ Map(
+ "activeStages" -> activeStages.size,
+ "activeJobs" -> activeJobs.size,
+ "poolToActiveStages" -> poolToActiveStages.values.map(_.size).sum,
+ "stageIdToActiveJobIds" -> stageIdToActiveJobIds.values.map(_.size).sum
+ )
+ }
- def blockManagerIds = executorIdToBlockManagerId.values.toSeq
+ // These collections should stop growing once we have run at least `spark.ui.retainedStages`
+ // stages and `spark.ui.retainedJobs` jobs:
+ private[spark] def getSizesOfHardSizeLimitedCollections: Map[String, Int] = {
+ Map(
+ "completedJobs" -> completedJobs.size,
+ "failedJobs" -> failedJobs.size,
+ "completedStages" -> completedStages.size,
+ "skippedStages" -> skippedStages.size,
+ "failedStages" -> failedStages.size
+ )
+ }
+
+ // These collections may grow arbitrarily, but once Spark becomes idle they should shrink back to
+ // some bound based on the `spark.ui.retainedStages` and `spark.ui.retainedJobs` settings:
+ private[spark] def getSizesOfSoftSizeLimitedCollections: Map[String, Int] = {
+ Map(
+ "jobIdToData" -> jobIdToData.size,
+ "stageIdToData" -> stageIdToData.size,
+ "stageIdToStageInfo" -> stageIdToInfo.size
+ )
+ }
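+  // For example, a unit test can assert that every value returned by
+  // getSizesOfActiveStateTrackingCollections is zero once all submitted jobs have finished.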
+
+  /** If the stage list grows too large, remove old stages so they can be garbage collected. */
+ private def trimStagesIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
+ if (stages.size > retainedStages) {
+ val toRemove = math.max(retainedStages / 10, 1)
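+      // Drop the oldest ~10% of the retained stages in one pass so that trimming is not
+      // triggered again on every subsequent stage completion.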
+ stages.take(toRemove).foreach { s =>
+ stageIdToData.remove((s.stageId, s.attemptId))
+ stageIdToInfo.remove(s.stageId)
+ }
+ stages.trimStart(toRemove)
+ }
+ }
+
+  /** If the job list grows too large, remove old jobs so they can be garbage collected. */
+ private def trimJobsIfNecessary(jobs: ListBuffer[JobUIData]) = synchronized {
+ if (jobs.size > retainedJobs) {
+ val toRemove = math.max(retainedJobs / 10, 1)
+ jobs.take(toRemove).foreach { job =>
+ jobIdToData.remove(job.jobId)
+ }
+ jobs.trimStart(toRemove)
+ }
+ }
override def onJobStart(jobStart: SparkListenerJobStart) = synchronized {
- val jobGroup = Option(jobStart.properties).map(_.getProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ val jobGroup = for (
+ props <- Option(jobStart.properties);
+ group <- Option(props.getProperty(SparkContext.SPARK_JOB_GROUP_ID))
+ ) yield group
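+    // (Unlike the previous `map`, this also yields None when the job group property itself
+    // is null rather than wrapping the null in Some.)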
val jobData: JobUIData =
- new JobUIData(jobStart.jobId, jobStart.stageIds, jobGroup, JobExecutionStatus.RUNNING)
+ new JobUIData(
+ jobId = jobStart.jobId,
+ startTime = Some(System.currentTimeMillis),
+ endTime = None,
+ stageIds = jobStart.stageIds,
+ jobGroup = jobGroup,
+ status = JobExecutionStatus.RUNNING)
+ // Compute (a potential underestimate of) the number of tasks that will be run by this job.
+ // This may be an underestimate because the job start event references all of the result
+      // stage's transitive stage dependencies, but some of these stages might be skipped if their
+ // output is available from earlier runs.
+ // See https://github.com/apache/spark/pull/3009 for a more extensive discussion.
+ jobData.numTasks = {
+ val allStages = jobStart.stageInfos
+ val missingStages = allStages.filter(_.completionTime.isEmpty)
+ missingStages.map(_.numTasks).sum
+ }
jobIdToData(jobStart.jobId) = jobData
activeJobs(jobStart.jobId) = jobData
+ for (stageId <- jobStart.stageIds) {
+ stageIdToActiveJobIds.getOrElseUpdate(stageId, new HashSet[StageId]).add(jobStart.jobId)
+ }
+ // If there's no information for a stage, store the StageInfo received from the scheduler
+ // so that we can display stage descriptions for pending stages:
+ for (stageInfo <- jobStart.stageInfos) {
+ stageIdToInfo.getOrElseUpdate(stageInfo.stageId, stageInfo)
+ stageIdToData.getOrElseUpdate((stageInfo.stageId, stageInfo.attemptId), new StageUIData)
+ }
}
override def onJobEnd(jobEnd: SparkListenerJobEnd) = synchronized {
@@ -82,14 +186,31 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
logWarning(s"Job completed for unknown job ${jobEnd.jobId}")
new JobUIData(jobId = jobEnd.jobId)
}
+ jobData.endTime = Some(System.currentTimeMillis())
jobEnd.jobResult match {
case JobSucceeded =>
completedJobs += jobData
+ trimJobsIfNecessary(completedJobs)
jobData.status = JobExecutionStatus.SUCCEEDED
case JobFailed(exception) =>
failedJobs += jobData
+ trimJobsIfNecessary(failedJobs)
jobData.status = JobExecutionStatus.FAILED
}
+ for (stageId <- jobData.stageIds) {
+ stageIdToActiveJobIds.get(stageId).foreach { jobsUsingStage =>
+ jobsUsingStage.remove(jobEnd.jobId)
+ stageIdToInfo.get(stageId).foreach { stageInfo =>
+ if (stageInfo.submissionTime.isEmpty) {
+ // if this stage is pending, it won't complete, so mark it as "skipped":
+ skippedStages += stageInfo
+ trimStagesIfNecessary(skippedStages)
+ jobData.numSkippedStages += 1
+ jobData.numSkippedTasks += stageInfo.numTasks
+ }
+ }
+ }
+ }
}
override def onStageCompleted(stageCompleted: SparkListenerStageCompleted) = synchronized {
@@ -110,22 +231,25 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
activeStages.remove(stage.stageId)
if (stage.failureReason.isEmpty) {
completedStages += stage
- trimIfNecessary(completedStages)
+ numCompletedStages += 1
+ trimStagesIfNecessary(completedStages)
} else {
failedStages += stage
- trimIfNecessary(failedStages)
+ numFailedStages += 1
+ trimStagesIfNecessary(failedStages)
}
- }
- /** If stages is too large, remove and garbage collect old stages */
- private def trimIfNecessary(stages: ListBuffer[StageInfo]) = synchronized {
- if (stages.size > retainedStages) {
- val toRemove = math.max(retainedStages / 10, 1)
- stages.take(toRemove).foreach { s =>
- stageIdToData.remove((s.stageId, s.attemptId))
- stageIdToInfo.remove(s.stageId)
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveStages -= 1
+ if (stage.failureReason.isEmpty) {
+ jobData.completedStageIndices.add(stage.stageId)
+ } else {
+ jobData.numFailedStages += 1
}
- stages.trimStart(toRemove)
}
}
@@ -148,6 +272,14 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
val stages = poolToActiveStages.getOrElseUpdate(poolName, new HashMap[Int, StageInfo])
stages(stage.stageId) = stage
+
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(stage.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveStages += 1
+ }
}
override def onTaskStart(taskStart: SparkListenerTaskStart) = synchronized {
@@ -160,6 +292,13 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.numActiveTasks += 1
stageData.taskData.put(taskInfo.taskId, new TaskUIData(taskInfo))
}
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskStart.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveTasks += 1
+ }
}
override def onTaskGettingResult(taskGettingResult: SparkListenerTaskGettingResult) {
@@ -217,6 +356,20 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
taskData.taskInfo = info
taskData.taskMetrics = metrics
taskData.errorMessage = errorMessage
+
+ for (
+ activeJobsDependentOnStage <- stageIdToActiveJobIds.get(taskEnd.stageId);
+ jobId <- activeJobsDependentOnStage;
+ jobData <- jobIdToData.get(jobId)
+ ) {
+ jobData.numActiveTasks -= 1
+ taskEnd.reason match {
+ case Success =>
+ jobData.numCompletedTasks += 1
+ case _ =>
+ jobData.numFailedTasks += 1
+ }
+ }
}
}
@@ -250,6 +403,12 @@ class JobProgressListener(conf: SparkConf) extends SparkListener with Logging {
stageData.inputBytes += inputBytesDelta
execSummary.inputBytes += inputBytesDelta
+ val outputBytesDelta =
+ (taskMetrics.outputMetrics.map(_.bytesWritten).getOrElse(0L)
+ - oldMetrics.flatMap(_.outputMetrics).map(_.bytesWritten).getOrElse(0L))
+ stageData.outputBytes += outputBytesDelta
+ execSummary.outputBytes += outputBytesDelta
+
val diskSpillDelta =
taskMetrics.diskBytesSpilled - oldMetrics.map(_.diskBytesSpilled).getOrElse(0L)
stageData.diskBytesSpilled += diskSpillDelta
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
new file mode 100644
index 0000000000000..b2bbfdee56946
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/JobsTab.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.ui.jobs
+
+import org.apache.spark.scheduler.SchedulingMode
+import org.apache.spark.ui.{SparkUI, SparkUITab}
+
+/** Web UI showing progress status of all jobs in the given SparkContext. */
+private[ui] class JobsTab(parent: SparkUI) extends SparkUITab(parent, "jobs") {
+ val sc = parent.sc
+ val killEnabled = parent.killEnabled
+ def isFairScheduler = listener.schedulingMode.exists(_ == SchedulingMode.FAIR)
+ val listener = parent.jobProgressListener
+
+ attachPage(new AllJobsPage(this))
+ attachPage(new JobPage(this))
+}
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
index 770d99eea1c9d..5fc6cc7533150 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolPage.scala
@@ -25,7 +25,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.{WebUIPage, UIUtils}
/** Page showing specific pool details */
-private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
+private[ui] class PoolPage(parent: StagesTab) extends WebUIPage("pool") {
private val sc = parent.sc
private val listener = parent.listener
@@ -37,8 +37,9 @@ private[ui] class PoolPage(parent: JobProgressTab) extends WebUIPage("pool") {
case Some(s) => s.values.toSeq
case None => Seq[StageInfo]()
}
- val activeStagesTable =
- new StageTableBase(activeStages.sortBy(_.submissionTime).reverse, parent)
+ val activeStagesTable = new StageTableBase(activeStages.sortBy(_.submissionTime).reverse,
+ parent.basePath, parent.listener, isFairScheduler = parent.isFairScheduler,
+ killEnabled = parent.killEnabled)
// For now, pool information is only accessible in live UIs
val pools = sc.map(_.getPoolForName(poolName).get).toSeq
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
index 64178e1e33d41..df1899e7a9b84 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/PoolTable.scala
@@ -24,7 +24,7 @@ import org.apache.spark.scheduler.{Schedulable, StageInfo}
import org.apache.spark.ui.UIUtils
/** Table showing list of pools */
-private[ui] class PoolTable(pools: Seq[Schedulable], parent: JobProgressTab) {
+private[ui] class PoolTable(pools: Seq[Schedulable], parent: StagesTab) {
private val listener = parent.listener
def toNodeSeq: Seq[Node] = {
diff --git a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
index 7cc03b7d333df..09a936c2234c0 100644
--- a/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
+++ b/core/src/main/scala/org/apache/spark/ui/jobs/StagePage.scala
@@ -22,6 +22,8 @@ import javax.servlet.http.HttpServletRequest
import scala.xml.{Node, Unparsed}
+import org.apache.commons.lang3.StringEscapeUtils
+
import org.apache.spark.executor.TaskMetrics
import org.apache.spark.ui.{ToolTips, WebUIPage, UIUtils}
import org.apache.spark.ui.jobs.UIData._
@@ -29,7 +31,7 @@ import org.apache.spark.util.{Utils, Distribution}
import org.apache.spark.scheduler.{AccumulableInfo, TaskInfo}
/** Page showing statistics and task list for a given stage */
-private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
+private[ui] class StagePage(parent: StagesTab) extends WebUIPage("stage") {
private val listener = parent.listener
def render(request: HttpServletRequest): Seq[Node] = {
@@ -55,6 +57,7 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
val accumulables = listener.stageIdToData((stageId, stageAttemptId)).accumulables
val hasAccumulators = accumulables.size > 0
val hasInput = stageData.inputBytes > 0
+ val hasOutput = stageData.outputBytes > 0
val hasShuffleRead = stageData.shuffleReadBytes > 0
val hasShuffleWrite = stageData.shuffleWriteBytes > 0
val hasBytesSpilled = stageData.memoryBytesSpilled > 0 && stageData.diskBytesSpilled > 0
@@ -72,6 +75,12 @@ private[ui] class StagePage(parent: JobProgressTab) extends WebUIPage("stage") {
{Utils.bytesToString(stageData.inputBytes)}
}}
+ {if (hasOutput) {
+
++
@@ -43,6 +44,7 @@ private[ui] class StageTableBase(
     <th>Duration</th>
     <th>Tasks: Succeeded/Total</th>
     <th>Input</th>
+    <th>Output</th>
     <th>Shuffle Read</th>
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-<property>
-  <name>build.dir</name>
-  <value>${user.dir}/build</value>
-</property>
-
-<property>
-  <name>build.dir.hive</name>
-  <value>${build.dir}/hive</value>
-</property>
-
-<property>
-  <name>hadoop.tmp.dir</name>
-  <value>${build.dir.hive}/test/hadoop-${user.name}</value>
-  <description>A base for other temporary directories.</description>
-</property>
-
-<property>
-  <name>hive.exec.scratchdir</name>
-  <value>${build.dir}/scratchdir</value>
-  <description>Scratch space for Hive jobs</description>
-</property>
-
-<property>
-  <name>hive.exec.local.scratchdir</name>
-  <value>${build.dir}/localscratchdir/</value>
-  <description>Local scratch space for Hive jobs</description>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionURL</name>
-  <value>jdbc:derby:;databaseName=../build/test/junit_metastore_db;create=true</value>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionDriverName</name>
-  <value>org.apache.derby.jdbc.EmbeddedDriver</value>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionUserName</name>
-  <value>APP</value>
-</property>
-
-<property>
-  <name>javax.jdo.option.ConnectionPassword</name>
-  <value>mine</value>
-</property>
-
-<property>
-  <name>hive.metastore.warehouse.dir</name>
-  <value>${test.warehouse.dir}</value>
-</property>
-
-<property>
-  <name>hive.metastore.metadb.dir</name>
-  <value>${build.dir}/test/data/metadb/</value>
-  <description>
-    Required by metastore server or if the uris argument below is not supplied
-  </description>
-</property>
-
-<property>
-  <name>test.log.dir</name>
-  <value>${build.dir}/test/logs</value>
-</property>
-
-<property>
-  <name>test.src.dir</name>
-  <value>${build.dir}/src/test</value>
-</property>
-
-<property>
-  <name>hive.jar.path</name>
-  <value>${build.dir.hive}/ql/hive-exec-${version}.jar</value>
-</property>
-
-<property>
-  <name>hive.metastore.rawstore.impl</name>
-  <value>org.apache.hadoop.hive.metastore.ObjectStore</value>
-  <description>Name of the class that implements org.apache.hadoop.hive.metastore.rawstore interface. This class is used to store and retrieval of raw metadata objects such as table, database</description>
-</property>
-
-<property>
-  <name>hive.querylog.location</name>
-  <value>${build.dir}/tmp</value>
-  <description>Location of the structured hive logs</description>
-</property>
-
-<property>
-  <name>hive.task.progress</name>
-  <value>false</value>
-  <description>Track progress of a task</description>
-</property>
-
-<property>
-  <name>hive.support.concurrency</name>
-  <value>false</value>
-  <description>Whether hive supports concurrency or not. A zookeeper instance must be up and running for the default hive lock manager to support read-write locks.</description>
-</property>
-
-<property>
-  <name>fs.pfile.impl</name>
-  <value>org.apache.hadoop.fs.ProxyLocalFileSystem</value>
-  <description>A proxy for local file system used for cross file system testing</description>
-</property>
-
-<property>
-  <name>hive.exec.mode.local.auto</name>
-  <value>false</value>
-  <description>
-    Let hive determine whether to run in local mode automatically
-    Disabling this for tests so that minimr is not affected
-  </description>
-</property>
-
-<property>
-  <name>hive.auto.convert.join</name>
-  <value>false</value>
-  <description>Whether Hive enable the optimization about converting common join into mapjoin based on the input file size</description>
-</property>
-
-<property>
-  <name>hive.ignore.mapjoin.hint</name>
-  <value>false</value>
-  <description>Whether Hive ignores the mapjoin hint</description>
-</property>
-
-<property>
-  <name>hive.input.format</name>
-  <value>org.apache.hadoop.hive.ql.io.CombineHiveInputFormat</value>
-  <description>The default input format, if it is not specified, the system assigns it. It is set to HiveInputFormat for hadoop versions 17, 18 and 19, whereas it is set to CombineHiveInputFormat for hadoop 20. The user can always overwrite it - if there is a bug in CombineHiveInputFormat, it can always be manually set to HiveInputFormat.</description>
-</property>
-
-<property>
-  <name>hive.default.rcfile.serde</name>
-  <value>org.apache.hadoop.hive.serde2.columnar.ColumnarSerDe</value>
-  <description>The default SerDe hive will use for the rcfile format</description>
-</property>
-
-</configuration>
diff --git a/dev/change-version-to-2.10.sh b/dev/change-version-to-2.10.sh
new file mode 100755
index 0000000000000..7473c20d28e09
--- /dev/null
+++ b/dev/change-version-to-2.10.sh
@@ -0,0 +1,20 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
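+# Rewrites every pom.xml outside target/ directories in place, switching artifact ids from
+# _2.11 to _2.10 (assumes it is run from the repository root).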
+find . -name 'pom.xml' | grep -v target \
+ | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.11|\1_2.10|g' {}
diff --git a/dev/change-version-to-2.11.sh b/dev/change-version-to-2.11.sh
new file mode 100755
index 0000000000000..3957a9f3ba258
--- /dev/null
+++ b/dev/change-version-to-2.11.sh
@@ -0,0 +1,21 @@
+#!/usr/bin/env bash
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+find . -name 'pom.xml' | grep -v target \
+ | xargs -I {} sed -i -e 's|\(artifactId.*\)_2.10|\1_2.11|g' {}
diff --git a/dev/create-release/create-release.sh b/dev/create-release/create-release.sh
index 281e8d4de6d71..71312e053c457 100755
--- a/dev/create-release/create-release.sh
+++ b/dev/create-release/create-release.sh
@@ -27,13 +27,19 @@
# Would be nice to add:
# - Send output to stderr and have useful logging in stdout
-GIT_USERNAME=${GIT_USERNAME:-pwendell}
-GIT_PASSWORD=${GIT_PASSWORD:-XXX}
+# Note: The following variables must be set before use!
+ASF_USERNAME=${ASF_USERNAME:-pwendell}
+ASF_PASSWORD=${ASF_PASSWORD:-XXX}
GPG_PASSPHRASE=${GPG_PASSPHRASE:-XXX}
GIT_BRANCH=${GIT_BRANCH:-branch-1.0}
-RELEASE_VERSION=${RELEASE_VERSION:-1.0.0}
+RELEASE_VERSION=${RELEASE_VERSION:-1.2.0}
+NEXT_VERSION=${NEXT_VERSION:-1.2.1}
RC_NAME=${RC_NAME:-rc2}
-USER_NAME=${USER_NAME:-pwendell}
+
+M2_REPO=~/.m2/repository
+SPARK_REPO=$M2_REPO/org/apache/spark
+NEXUS_ROOT=https://repository.apache.org/service/local/staging
+NEXUS_PROFILE=d63f592e7eac0 # Profile for Spark staging uploads
if [ -z "$JAVA_HOME" ]; then
echo "Error: JAVA_HOME is not set, cannot proceed."
@@ -46,31 +52,106 @@ set -e
GIT_TAG=v$RELEASE_VERSION-$RC_NAME
if [[ ! "$@" =~ --package-only ]]; then
- echo "Creating and publishing release"
+ echo "Creating release commit and publishing to Apache repository"
# Artifact publishing
- git clone https://git-wip-us.apache.org/repos/asf/spark.git -b $GIT_BRANCH
- cd spark
+ git clone https://$ASF_USERNAME:$ASF_PASSWORD@git-wip-us.apache.org/repos/asf/spark.git \
+ -b $GIT_BRANCH
+ pushd spark
export MAVEN_OPTS="-Xmx3g -XX:MaxPermSize=1g -XX:ReservedCodeCacheSize=1g"
- mvn -Pyarn release:clean
-
- mvn -DskipTests \
- -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
- -Dusername=$GIT_USERNAME -Dpassword=$GIT_PASSWORD \
- -Dmaven.javadoc.skip=true \
- -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Dtag=$GIT_TAG -DautoVersionSubmodules=true \
- -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
- --batch-mode release:prepare
-
- mvn -DskipTests \
- -Darguments="-DskipTests=true -Dmaven.javadoc.skip=true -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 -Dgpg.passphrase=${GPG_PASSPHRASE}" \
- -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
- -Dmaven.javadoc.skip=true \
+ # Create release commits and push them to github
+  # NOTE: This is done "eagerly", i.e. we don't check that we can successfully build
+  # before we coin the release commit. This helps avoid races where
+  # other people add commits to this branch while we are in the middle of building.
+ cur_ver="${RELEASE_VERSION}-SNAPSHOT"
+ rel_ver="${RELEASE_VERSION}"
+ next_ver="${NEXT_VERSION}-SNAPSHOT"
+
+  old="^\( \{2,4\}\)<version>${cur_ver}<\/version>$"
+  new="\1<version>${rel_ver}<\/version>"
+ find . -name pom.xml | grep -v dev | xargs -I {} sed -i \
+ -e "s/${old}/${new}/" {}
+ find . -name package.scala | grep -v dev | xargs -I {} sed -i \
+ -e "s/${old}/${new}/" {}
+
+ git commit -a -m "Preparing Spark release $GIT_TAG"
+ echo "Creating tag $GIT_TAG at the head of $GIT_BRANCH"
+ git tag $GIT_TAG
+
+  old="^\( \{2,4\}\)<version>${rel_ver}<\/version>$"
+  new="\1<version>${next_ver}<\/version>"
+ find . -name pom.xml | grep -v dev | xargs -I {} sed -i \
+ -e "s/$old/$new/" {}
+ find . -name package.scala | grep -v dev | xargs -I {} sed -i \
+ -e "s/${old}/${new}/" {}
+ git commit -a -m "Preparing development version $next_ver"
+ git push origin $GIT_TAG
+ git push origin HEAD:$GIT_BRANCH
+ git checkout -f $GIT_TAG
+
+ # Using Nexus API documented here:
+ # https://support.sonatype.com/entries/39720203-Uploading-to-a-Staging-Repository-via-REST-API
+ echo "Creating Nexus staging repository"
+ repo_request="Apache Spark $GIT_TAG"
+ out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+ -H "Content-Type:application/xml" -v \
+ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/start)
+ staged_repo_id=$(echo $out | sed -e "s/.*\(orgapachespark-[0-9]\{4\}\).*/\1/")
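+  # The sed above pulls the new repository id (e.g. "orgapachespark-1234") out of the response.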
+ echo "Created Nexus staging repository: $staged_repo_id"
+
+ rm -rf $SPARK_REPO
+
+ mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
-Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
- release:perform
+ clean install
- cd ..
+ ./dev/change-version-to-2.11.sh
+
+ mvn -DskipTests -Dhadoop.version=2.2.0 -Dyarn.version=2.2.0 \
+ -Dscala-2.11 -Pyarn -Phive -Phadoop-2.2 -Pspark-ganglia-lgpl -Pkinesis-asl \
+ clean install
+
+ ./dev/change-version-to-2.10.sh
+
+ pushd $SPARK_REPO
+
+ # Remove any extra files generated during install
+ find . -type f |grep -v \.jar |grep -v \.pom | xargs rm
+
+ echo "Creating hash and signature files"
+ for file in $(find . -type f)
+ do
+ echo $GPG_PASSPHRASE | gpg --passphrase-fd 0 --output $file.asc --detach-sig --armour $file;
+ if [ $(command -v md5) ]; then
+ # Available on OS X; -q to keep only hash
+ md5 -q $file > $file.md5
+ else
+ # Available on Linux; cut to keep only hash
+ md5sum $file | cut -f1 -d' ' > $file.md5
+ fi
+ shasum -a 1 $file | cut -f1 -d' ' > $file.sha1
+ done
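+  # Each artifact foo.jar now has foo.jar.asc, foo.jar.md5 and foo.jar.sha1 next to it.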
+
+ nexus_upload=$NEXUS_ROOT/deployByRepositoryId/$staged_repo_id
+ echo "Uplading files to $nexus_upload"
+ for file in $(find . -type f)
+ do
+ # strip leading ./
+ file_short=$(echo $file | sed -e "s/\.\///")
+ dest_url="$nexus_upload/org/apache/spark/$file_short"
+ echo " Uploading $file_short"
+ curl -u $ASF_USERNAME:$ASF_PASSWORD --upload-file $file_short $dest_url
+ done
+
+ echo "Closing nexus staging repository"
+ repo_request="$staged_repo_idApache Spark $GIT_TAG"
+ out=$(curl -X POST -d "$repo_request" -u $ASF_USERNAME:$ASF_PASSWORD \
+ -H "Content-Type:application/xml" -v \
+ $NEXUS_ROOT/profiles/$NEXUS_PROFILE/finish)
+ echo "Closed Nexus staging repository: $staged_repo_id"
+
+ popd
+ popd
rm -rf spark
fi
@@ -101,7 +182,13 @@ make_binary_release() {
cp -r spark spark-$RELEASE_VERSION-bin-$NAME
cd spark-$RELEASE_VERSION-bin-$NAME
- ./make-distribution.sh --name $NAME --tgz $FLAGS
+
+ # TODO There should probably be a flag to make-distribution to allow 2.11 support
+ if [[ $FLAGS == *scala-2.11* ]]; then
+ ./dev/change-version-to-2.11.sh
+ fi
+
+ ./make-distribution.sh --name $NAME --tgz $FLAGS 2>&1 | tee ../binary-release-$NAME.log
cd ..
cp spark-$RELEASE_VERSION-bin-$NAME/spark-$RELEASE_VERSION-bin-$NAME.tgz .
rm -rf spark-$RELEASE_VERSION-bin-$NAME
@@ -117,22 +204,24 @@ make_binary_release() {
spark-$RELEASE_VERSION-bin-$NAME.tgz.sha
}
-make_binary_release "hadoop1" "-Phive -Dhadoop.version=1.0.4" &
-make_binary_release "cdh4" "-Phive -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
-make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Pyarn" &
-make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Pyarn" &
+
+make_binary_release "hadoop1" "-Phive -Phive-thriftserver -Dhadoop.version=1.0.4" &
+make_binary_release "hadoop1-scala2.11" "-Phive -Dscala-2.11" &
+make_binary_release "cdh4" "-Phive -Phive-thriftserver -Dhadoop.version=2.0.0-mr1-cdh4.2.0" &
+make_binary_release "hadoop2.3" "-Phadoop-2.3 -Phive -Phive-thriftserver -Pyarn" &
+make_binary_release "hadoop2.4" "-Phadoop-2.4 -Phive -Phive-thriftserver -Pyarn" &
+make_binary_release "mapr3" "-Pmapr3 -Phive -Phive-thriftserver" &
+make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive -Phive-thriftserver" &
make_binary_release "hadoop2.4-without-hive" "-Phadoop-2.4 -Pyarn" &
-make_binary_release "mapr3" "-Pmapr3 -Phive" &
-make_binary_release "mapr4" "-Pmapr4 -Pyarn -Phive" &
wait
# Copy data
echo "Copying release tarballs"
rc_folder=spark-$RELEASE_VERSION-$RC_NAME
-ssh $USER_NAME@people.apache.org \
- mkdir /home/$USER_NAME/public_html/$rc_folder
+ssh $ASF_USERNAME@people.apache.org \
+ mkdir /home/$ASF_USERNAME/public_html/$rc_folder
scp spark-* \
- $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_folder/
+ $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_folder/
# Docs
cd spark
@@ -142,12 +231,12 @@ cd docs
JAVA_HOME=$JAVA_7_HOME PRODUCTION=1 jekyll build
echo "Copying release documentation"
rc_docs_folder=${rc_folder}-docs
-ssh $USER_NAME@people.apache.org \
- mkdir /home/$USER_NAME/public_html/$rc_docs_folder
-rsync -r _site/* $USER_NAME@people.apache.org:/home/$USER_NAME/public_html/$rc_docs_folder
+ssh $ASF_USERNAME@people.apache.org \
+ mkdir /home/$ASF_USERNAME/public_html/$rc_docs_folder
+rsync -r _site/* $ASF_USERNAME@people.apache.org:/home/$ASF_USERNAME/public_html/$rc_docs_folder
echo "Release $RELEASE_VERSION completed:"
echo "Git tag:\t $GIT_TAG"
echo "Release commit:\t $release_hash"
-echo "Binary location:\t http://people.apache.org/~$USER_NAME/$rc_folder"
-echo "Doc location:\t http://people.apache.org/~$USER_NAME/$rc_docs_folder"
+echo "Binary location:\t http://people.apache.org/~$ASF_USERNAME/$rc_folder"
+echo "Doc location:\t http://people.apache.org/~$ASF_USERNAME/$rc_docs_folder"
diff --git a/dev/create-release/generate-contributors.py b/dev/create-release/generate-contributors.py
new file mode 100755
index 0000000000000..8aaa250bd7e29
--- /dev/null
+++ b/dev/create-release/generate-contributors.py
@@ -0,0 +1,244 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This script automates the process of creating release notes.
+
+import os
+import re
+import sys
+
+from releaseutils import *
+
+# You must set the following before use!
+JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
+RELEASE_TAG = os.environ.get("RELEASE_TAG", "v1.2.0-rc2")
+PREVIOUS_RELEASE_TAG = os.environ.get("PREVIOUS_RELEASE_TAG", "v1.1.0")
+
+# If the release tags are not provided, prompt the user to provide them
+while not tag_exists(RELEASE_TAG):
+ RELEASE_TAG = raw_input("Please provide a valid release tag: ")
+while not tag_exists(PREVIOUS_RELEASE_TAG):
+ print "Please specify the previous release tag."
+ PREVIOUS_RELEASE_TAG = raw_input(\
+ "For instance, if you are releasing v1.2.0, you should specify v1.1.0: ")
+
+# Gather commits found in the new tag but not in the old tag.
+# This filters commits based on both the git hash and the PR number.
+# If either is present in the old tag, then we ignore the commit.
+print "Gathering new commits between tags %s and %s" % (PREVIOUS_RELEASE_TAG, RELEASE_TAG)
+release_commits = get_commits(RELEASE_TAG)
+previous_release_commits = get_commits(PREVIOUS_RELEASE_TAG)
+previous_release_hashes = set()
+previous_release_prs = set()
+for old_commit in previous_release_commits:
+ previous_release_hashes.add(old_commit.get_hash())
+ if old_commit.get_pr_number():
+ previous_release_prs.add(old_commit.get_pr_number())
+new_commits = []
+for this_commit in release_commits:
+ this_hash = this_commit.get_hash()
+ this_pr_number = this_commit.get_pr_number()
+ if this_hash in previous_release_hashes:
+ continue
+ if this_pr_number and this_pr_number in previous_release_prs:
+ continue
+ new_commits.append(this_commit)
+if not new_commits:
+ sys.exit("There are no new commits between %s and %s!" % (PREVIOUS_RELEASE_TAG, RELEASE_TAG))
+
+# Prompt the user for confirmation that the commit range is correct
+print "\n=================================================================================="
+print "JIRA server: %s" % JIRA_API_BASE
+print "Release tag: %s" % RELEASE_TAG
+print "Previous release tag: %s" % PREVIOUS_RELEASE_TAG
+print "Number of commits in this range: %s" % len(new_commits)
+print
+def print_indented(_list):
+ for x in _list: print " %s" % x
+if yesOrNoPrompt("Show all commits?"):
+ print_indented(new_commits)
+print "==================================================================================\n"
+if not yesOrNoPrompt("Does this look correct?"):
+ sys.exit("Ok, exiting")
+
+# Filter out special commits
+releases = []
+maintenance = []
+reverts = []
+nojiras = []
+filtered_commits = []
+def is_release(commit_title):
+ return re.findall("\[release\]", commit_title.lower()) or\
+ "preparing spark release" in commit_title.lower() or\
+ "preparing development version" in commit_title.lower() or\
+ "CHANGES.txt" in commit_title
+def is_maintenance(commit_title):
+ return "maintenance" in commit_title.lower() or\
+ "manually close" in commit_title.lower()
+def has_no_jira(commit_title):
+ return not re.findall("SPARK-[0-9]+", commit_title.upper())
+def is_revert(commit_title):
+ return "revert" in commit_title.lower()
+def is_docs(commit_title):
+ return re.findall("docs*", commit_title.lower()) or\
+ "programming guide" in commit_title.lower()
+for c in new_commits:
+ t = c.get_title()
+ if not t: continue
+ elif is_release(t): releases.append(c)
+ elif is_maintenance(t): maintenance.append(c)
+ elif is_revert(t): reverts.append(c)
+ elif is_docs(t): filtered_commits.append(c) # docs may not have JIRA numbers
+ elif has_no_jira(t): nojiras.append(c)
+ else: filtered_commits.append(c)
+
+# Warn the user about commits that will be ignored
+if releases or maintenance or reverts or nojiras:
+ print "\n=================================================================================="
+ if releases: print "Found %d release commits" % len(releases)
+ if maintenance: print "Found %d maintenance commits" % len(maintenance)
+ if reverts: print "Found %d revert commits" % len(reverts)
+ if nojiras: print "Found %d commits with no JIRA" % len(nojiras)
+ print "* Warning: these commits will be ignored.\n"
+ if yesOrNoPrompt("Show ignored commits?"):
+ if releases: print "Release (%d)" % len(releases); print_indented(releases)
+ if maintenance: print "Maintenance (%d)" % len(maintenance); print_indented(maintenance)
+ if reverts: print "Revert (%d)" % len(reverts); print_indented(reverts)
+ if nojiras: print "No JIRA (%d)" % len(nojiras); print_indented(nojiras)
+ print "==================== Warning: the above commits will be ignored ==================\n"
+prompt_msg = "%d commits left to process after filtering. Ok to proceed?" % len(filtered_commits)
+if not yesOrNoPrompt(prompt_msg):
+ sys.exit("Ok, exiting.")
+
+# Keep track of warnings to tell the user at the end
+warnings = []
+
+# Mapping from the invalid author name to its associated JIRA issues
+# E.g. andrewor14 -> set("SPARK-2413", "SPARK-3551", "SPARK-3471")
+invalid_authors = {}
+
+# Populate a map that groups issues and components by author
+# It takes the form: Author name -> { Contribution type -> Spark components }
+# For instance,
+# {
+# 'Andrew Or': {
+# 'bug fixes': ['windows', 'core', 'web ui'],
+# 'improvements': ['core']
+# },
+# 'Tathagata Das' : {
+# 'bug fixes': ['streaming']
+# 'new feature': ['streaming']
+# }
+# }
+#
+author_info = {}
+jira_options = { "server": JIRA_API_BASE }
+jira_client = JIRA(options = jira_options)
+print "\n=========================== Compiling contributor list ==========================="
+for commit in filtered_commits:
+ _hash = commit.get_hash()
+ title = commit.get_title()
+ issues = re.findall("SPARK-[0-9]+", title.upper())
+ author = commit.get_author()
+ date = get_date(_hash)
+ # If the author name is invalid, keep track of it along
+ # with all associated issues so we can translate it later
+ if is_valid_author(author):
+ author = capitalize_author(author)
+ else:
+ if author not in invalid_authors:
+ invalid_authors[author] = set()
+ for issue in issues:
+ invalid_authors[author].add(issue)
+ # Parse components from the commit title, if any
+ commit_components = find_components(title, _hash)
+ # Populate or merge an issue into author_info[author]
+ def populate(issue_type, components):
+ components = components or [CORE_COMPONENT] # assume core if no components provided
+ if author not in author_info:
+ author_info[author] = {}
+ if issue_type not in author_info[author]:
+ author_info[author][issue_type] = set()
+ for component in components:
+ author_info[author][issue_type].add(component)
+ # Find issues and components associated with this commit
+ for issue in issues:
+ jira_issue = jira_client.issue(issue)
+ jira_type = jira_issue.fields.issuetype.name
+ jira_type = translate_issue_type(jira_type, issue, warnings)
+ jira_components = [translate_component(c.name, _hash, warnings)\
+ for c in jira_issue.fields.components]
+ all_components = set(jira_components + commit_components)
+ populate(jira_type, all_components)
+    # For docs commits without an associated JIRA, add the contribution manually
+ if is_docs(title) and not issues:
+ populate("documentation", commit_components)
+ print " Processed commit %s authored by %s on %s" % (_hash, author, date)
+print "==================================================================================\n"
+
+# Write to contributors file ordered by author names
+# Each line takes the format " * Author name -- semicolon-delimited contributions"
+# e.g. * Andrew Or -- Bug fixes in Windows, Core, and Web UI; improvements in Core
+# e.g. * Tathagata Das -- Bug fixes and new features in Streaming
+contributors_file = open(contributors_file_name, "w")
+authors = author_info.keys()
+authors.sort()
+for author in authors:
+ contribution = ""
+ components = set()
+ issue_types = set()
+ for issue_type, comps in author_info[author].items():
+ components.update(comps)
+ issue_types.add(issue_type)
+ # If there is only one component, mention it only once
+ # e.g. Bug fixes, improvements in MLlib
+ if len(components) == 1:
+ contribution = "%s in %s" % (nice_join(issue_types), next(iter(components)))
+ # Otherwise, group contributions by issue types instead of modules
+ # e.g. Bug fixes in MLlib, Core, and Streaming; documentation in YARN
+ else:
+ contributions = ["%s in %s" % (issue_type, nice_join(comps)) \
+ for issue_type, comps in author_info[author].items()]
+ contribution = "; ".join(contributions)
+    # Do not use python's capitalize() on the whole string; it would lowercase the rest of it
+ assert contribution
+ contribution = contribution[0].capitalize() + contribution[1:]
+ # If the author name is invalid, use an intermediate format that
+ # can be translated through translate-contributors.py later
+ # E.g. andrewor14/SPARK-3425/SPARK-1157/SPARK-6672
+ if author in invalid_authors and invalid_authors[author]:
+ author = author + "/" + "/".join(invalid_authors[author])
+ line = " * %s -- %s" % (author, contribution)
+ contributors_file.write(line + "\n")
+contributors_file.close()
+print "Contributors list is successfully written to %s!" % contributors_file_name
+
+# Prompt the user to translate author names if necessary
+if invalid_authors:
+ warnings.append("Found the following invalid authors:")
+ for a in invalid_authors:
+ warnings.append("\t%s" % a)
+ warnings.append("Please run './translate-contributors.py' to translate them.")
+
+# Log any warnings encountered in the process
+if warnings:
+ print "\n============ Warnings encountered while creating the contributor list ============"
+ for w in warnings: print w
+ print "Please correct these in the final contributors list at %s." % contributors_file_name
+ print "==================================================================================\n"
+
diff --git a/dev/create-release/known_translations b/dev/create-release/known_translations
new file mode 100644
index 0000000000000..b74e4ee8a330b
--- /dev/null
+++ b/dev/create-release/known_translations
@@ -0,0 +1,59 @@
+# This is a mapping of names to be translated through translate-contributors.py
+# The format expected on each line should be: <old name> - <new name>
+CodingCat - Nan Zhu
+CrazyJvm - Chao Chen
+EugenCepoi - Eugen Cepoi
+GraceH - Jie Huang
+JerryLead - Lijie Xu
+Leolh - Liu Hao
+Lewuathe - Kai Sasaki
+RongGu - Rong Gu
+Shiti - Shiti Saxena
+Victsm - Min Shen
+WangTaoTheTonic - Wang Tao
+XuTingjun - Tingjun Xu
+YanTangZhai - Yantang Zhai
+alexdebrie - Alex DeBrie
+alokito - Alok Saldanha
+anantasty - Anant Asthana
+andrewor14 - Andrew Or
+aniketbhatnagar - Aniket Bhatnagar
+arahuja - Arun Ahuja
+brkyvz - Burak Yavuz
+chesterxgchen - Chester Chen
+chiragaggarwal - Chirag Aggarwal
+chouqin - Qiping Li
+cocoatomo - Tomohiko K.
+coderfi - Fairiz Azizi
+coderxiang - Shuo Xiang
+davies - Davies Liu
+epahomov - Egor Pahomov
+falaki - Hossein Falaki
+freeman-lab - Jeremy Freeman
+industrial-sloth - Jascha Swisher
+jackylk - Jacky Li
+jayunit100 - Jay Vyas
+jerryshao - Saisai Shao
+jkbradley - Joseph Bradley
+lianhuiwang - Lianhui Wang
+lirui-intel - Rui Li
+luluorta - Lu Lu
+luogankun - Gankun Luo
+maji2014 - Derek Ma
+mccheah - Matthew Cheah
+mengxr - Xiangrui Meng
+nartz - Nathan Artz
+odedz - Oded Zimerman
+ravipesala - Ravindra Pesala
+roxchkplusony - Victor Tso
+scwf - Wang Fei
+shimingfei - Shiming Fei
+surq - Surong Quan
+suyanNone - Su Yan
+tedyu - Ted Yu
+tigerquoll - Dale Richardson
+wangxiaojing - Xiaojing Wang
+watermen - Yadong Qi
+witgo - Guoqiang Li
+xinyunh - Xinyun Huang
+zsxwing - Shixiong Zhu
diff --git a/dev/create-release/releaseutils.py b/dev/create-release/releaseutils.py
new file mode 100755
index 0000000000000..26221b270394e
--- /dev/null
+++ b/dev/create-release/releaseutils.py
@@ -0,0 +1,256 @@
+#!/usr/bin/env python
+
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+# This file contains helper methods used in creating a release.
+
+import re
+import sys
+from subprocess import Popen, PIPE
+
+try:
+ from jira.client import JIRA
+ from jira.exceptions import JIRAError
+except ImportError:
+ print "This tool requires the jira-python library"
+ print "Install using 'sudo pip install jira-python'"
+ sys.exit(-1)
+
+try:
+ from github import Github
+ from github import GithubException
+except ImportError:
+ print "This tool requires the PyGithub library"
+ print "Install using 'sudo pip install PyGithub'"
+ sys.exit(-1)
+
+try:
+ import unidecode
+except ImportError:
+ print "This tool requires the unidecode library to decode obscure github usernames"
+ print "Install using 'sudo pip install unidecode'"
+ sys.exit(-1)
+
+# Contributors list file name
+contributors_file_name = "contributors.txt"
+
+# Prompt the user to answer yes or no until they do so
+def yesOrNoPrompt(msg):
+ response = raw_input("%s [y/n]: " % msg)
+ while response != "y" and response != "n":
+ return yesOrNoPrompt(msg)
+ return response == "y"
+
+# Utility functions for running git commands (written against Git 1.8.5)
+def run_cmd(cmd): return Popen(cmd, stdout=PIPE).communicate()[0]
+def run_cmd_error(cmd): return Popen(cmd, stdout=PIPE, stderr=PIPE).communicate()[1]
+def get_date(commit_hash):
+ return run_cmd(["git", "show", "--quiet", "--pretty=format:%cd", commit_hash])
+def tag_exists(tag):
+ stderr = run_cmd_error(["git", "show", tag])
+ return "error" not in stderr
+
+# A type-safe representation of a commit
+class Commit:
+ def __init__(self, _hash, author, title, pr_number = None):
+ self._hash = _hash
+ self.author = author
+ self.title = title
+ self.pr_number = pr_number
+ def get_hash(self): return self._hash
+ def get_author(self): return self.author
+ def get_title(self): return self.title
+ def get_pr_number(self): return self.pr_number
+ def __str__(self):
+ closes_pr = "(Closes #%s)" % self.pr_number if self.pr_number else ""
+ return "%s %s %s %s" % (self._hash, self.author, self.title, closes_pr)
+
+# Return all commits that belong to the specified tag.
+#
+# Under the hood, this runs a `git log` on that tag and parses the fields
+# from the command output to construct a list of Commit objects. Note that
+# because certain fields reside in the commit description and cannot be parsed
+# through the Github API itself, we need to do some intelligent regex parsing
+# to extract those fields.
+#
+# This is written using Git 1.8.5.
+def get_commits(tag):
+ commit_start_marker = "|=== COMMIT START MARKER ===|"
+ commit_end_marker = "|=== COMMIT END MARKER ===|"
+ field_end_marker = "|=== COMMIT FIELD END MARKER ===|"
+ log_format =\
+ commit_start_marker + "%h" +\
+ field_end_marker + "%an" +\
+ field_end_marker + "%s" +\
+ commit_end_marker + "%b"
+ output = run_cmd(["git", "log", "--quiet", "--pretty=format:" + log_format, tag])
+ commits = []
+ raw_commits = [c for c in output.split(commit_start_marker) if c]
+ for commit in raw_commits:
+ if commit.count(commit_end_marker) != 1:
+ print "Commit end marker not found in commit: "
+ for line in commit.split("\n"): print line
+ sys.exit(1)
+ # Separate commit digest from the body
+ # From the digest we extract the hash, author and the title
+ # From the body, we extract the PR number and the github username
+ [commit_digest, commit_body] = commit.split(commit_end_marker)
+ if commit_digest.count(field_end_marker) != 2:
+ sys.exit("Unexpected format in commit: %s" % commit_digest)
+ [_hash, author, title] = commit_digest.split(field_end_marker)
+        # The PR number and github username are in the commit message
+        # itself and cannot be accessed through the Github API
+ pr_number = None
+ match = re.search("Closes #([0-9]+) from ([^/\\s]+)/", commit_body)
+ if match:
+ [pr_number, github_username] = match.groups()
+ # If the author name is not valid, use the github
+ # username so we can translate it properly later
+ if not is_valid_author(author):
+ author = github_username
+ # Guard against special characters
+ author = unidecode.unidecode(unicode(author, "UTF-8")).strip()
+ commit = Commit(_hash, author, title, pr_number)
+ commits.append(commit)
+ return commits
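+# For example, a commit body containing the (hypothetical) line
+#   "Closes #3421 from andrewor14/fix-ui and squashes the following commits: ..."
+# yields pr_number = "3421" and github_username = "andrewor14" above.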
+
+# Maintain a mapping for translating issue types to contributions in the release notes
+# This serves an additional function of warning the user against unknown issue types
+# Note: This list is partially derived from this link:
+# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/issuetypes
+# Keep these in lower case
+known_issue_types = {
+ "bug": "bug fixes",
+ "build": "build fixes",
+ "dependency upgrade": "build fixes",
+ "improvement": "improvements",
+ "new feature": "new features",
+ "documentation": "documentation",
+ "test": "test",
+ "task": "improvement",
+ "sub-task": "improvement"
+}
+
+# Maintain a mapping for translating component names when creating the release notes
+# This serves an additional function of warning the user against unknown components
+# Note: This list is largely derived from this link:
+# https://issues.apache.org/jira/plugins/servlet/project-config/SPARK/components
+CORE_COMPONENT = "Core"
+known_components = {
+ "block manager": CORE_COMPONENT,
+ "build": CORE_COMPONENT,
+ "deploy": CORE_COMPONENT,
+ "documentation": CORE_COMPONENT,
+ "ec2": "EC2",
+ "examples": CORE_COMPONENT,
+ "graphx": "GraphX",
+ "input/output": CORE_COMPONENT,
+ "java api": "Java API",
+ "mesos": "Mesos",
+ "ml": "MLlib",
+ "mllib": "MLlib",
+ "project infra": "Project Infra",
+ "pyspark": "PySpark",
+ "shuffle": "Shuffle",
+ "spark core": CORE_COMPONENT,
+ "spark shell": CORE_COMPONENT,
+ "sql": "SQL",
+ "streaming": "Streaming",
+ "web ui": "Web UI",
+ "windows": "Windows",
+ "yarn": "YARN"
+}
+
+# Translate issue types using a format appropriate for writing contributions
+# If an unknown issue type is encountered, warn the user
+def translate_issue_type(issue_type, issue_id, warnings):
+ issue_type = issue_type.lower()
+ if issue_type in known_issue_types:
+ return known_issue_types[issue_type]
+ else:
+ warnings.append("Unknown issue type \"%s\" (see %s)" % (issue_type, issue_id))
+ return issue_type
+
+# Translate component names using a format appropriate for writing contributions
+# If an unknown component is encountered, warn the user
+def translate_component(component, commit_hash, warnings):
+ component = component.lower()
+ if component in known_components:
+ return known_components[component]
+ else:
+ warnings.append("Unknown component \"%s\" (see %s)" % (component, commit_hash))
+ return component
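+# For example (hypothetical): translate_component("pyspark", "abc1234", warnings) returns
+# "PySpark", while an unrecognized component is returned as is (lower-cased) and a warning
+# is appended to the given list.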
+
+# Parse components in the commit message
+# The returned components are already filtered and translated
+def find_components(commit, commit_hash):
+    components = re.findall("\[(\w*)\]", commit.lower())
+    components = [translate_component(c, commit_hash, [])
+        for c in components if c in known_components]
+ return components
+
+# Join a list of strings in a human-readable manner
+# e.g. ["Juice"] -> "Juice"
+# e.g. ["Juice", "baby"] -> "Juice and baby"
+# e.g. ["Juice", "baby", "moon"] -> "Juice, baby, and moon"
+def nice_join(str_list):
+ str_list = list(str_list) # sometimes it's a set
+ if not str_list:
+ return ""
+ elif len(str_list) == 1:
+ return next(iter(str_list))
+ elif len(str_list) == 2:
+ return " and ".join(str_list)
+ else:
+ return ", ".join(str_list[:-1]) + ", and " + str_list[-1]
+
+# Return the full name of the specified user on Github
+# If the user doesn't exist, return None
+def get_github_name(author, github_client):
+ if github_client:
+ try:
+ return github_client.get_user(author).name
+ except GithubException as e:
+ # If this is not a "not found" exception
+ if e.status != 404:
+ raise e
+ return None
+
+# Return the full name of the specified user on JIRA
+# If the user doesn't exist, return None
+def get_jira_name(author, jira_client):
+ if jira_client:
+ try:
+ return jira_client.user(author).displayName
+ except JIRAError as e:
+ # If this is not a "not found" exception
+ if e.status_code != 404:
+ raise e
+ return None
+
+# Return whether the given author name is valid, i.e. a full name with a space and no digits
+def is_valid_author(author):
+ if not author: return False
+ return " " in author and not re.findall("[0-9]", author)
+
+# Capitalize the first letter of each word in the given author name
+def capitalize_author(author):
+ if not author: return None
+ words = author.split(" ")
+ words = [w[0].capitalize() + w[1:] for w in words if w]
+ return " ".join(words)
+
diff --git a/dev/create-release/translate-contributors.py b/dev/create-release/translate-contributors.py
new file mode 100755
index 0000000000000..86fa02d87b9a0
--- /dev/null
+++ b/dev/create-release/translate-contributors.py
@@ -0,0 +1,253 @@
+#!/usr/bin/env python
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements. See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+# This script translates invalid authors in the contributors list generated
+# by generate-contributors.py. When the script encounters an author name that
+# is considered invalid, it searches Github and JIRA in an attempt to find
+# replacements. This tool runs in two modes:
+#
+# (1) Interactive mode: For each invalid author name, this script presents
+# all candidate replacements to the user and awaits user response. In this
+# mode, the user may also input a custom name. This is the default.
+#
+# (2) Non-interactive mode: For each invalid author name, this script replaces
+# the name with the first valid candidate it can find. If there is none, it
+# uses the original name. This can be enabled through the --non-interactive flag.
+
+import os
+import sys
+
+from releaseutils import *
+
+# You must set the following before use!
+JIRA_API_BASE = os.environ.get("JIRA_API_BASE", "https://issues.apache.org/jira")
+JIRA_USERNAME = os.environ.get("JIRA_USERNAME", None)
+JIRA_PASSWORD = os.environ.get("JIRA_PASSWORD", None)
+GITHUB_API_TOKEN = os.environ.get("GITHUB_API_TOKEN", None)
+if not JIRA_USERNAME or not JIRA_PASSWORD:
+ sys.exit("Both JIRA_USERNAME and JIRA_PASSWORD must be set")
+if not GITHUB_API_TOKEN:
+ sys.exit("GITHUB_API_TOKEN must be set")
+
+# Write the new contributors list to <contributors file name>.final
+if not os.path.isfile(contributors_file_name):
+ print "Contributors file %s does not exist!" % contributors_file_name
+ print "Have you run ./generate-contributors.py yet?"
+ sys.exit(1)
+contributors_file = open(contributors_file_name, "r")
+warnings = []
+
+# In non-interactive mode, this script will choose the first replacement that is valid
+INTERACTIVE_MODE = True
+if len(sys.argv) > 1:
+ options = set(sys.argv[1:])
+ if "--non-interactive" in options:
+ INTERACTIVE_MODE = False
+if INTERACTIVE_MODE:
+ print "Running in interactive mode. To disable this, provide the --non-interactive flag."
+
+# Setup Github and JIRA clients
+jira_options = { "server": JIRA_API_BASE }
+jira_client = JIRA(options = jira_options, basic_auth = (JIRA_USERNAME, JIRA_PASSWORD))
+github_client = Github(GITHUB_API_TOKEN)
+
+# Load known author translations that are cached locally
+known_translations = {}
+known_translations_file_name = "known_translations"
+known_translations_file = open(known_translations_file_name, "r")
+for line in known_translations_file:
+ if line.startswith("#"): continue
+ [old_name, new_name] = line.strip("\n").split(" - ")
+ known_translations[old_name] = new_name
+known_translations_file.close()
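+# For example, the line "andrewor14 - Andrew Or" from known_translations produces
+# the entry known_translations["andrewor14"] == "Andrew Or".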
+
+# Open again in case the user adds new mappings
+known_translations_file = open(known_translations_file_name, "a")
+
+# Generate candidates for the given author. This should only be called if the given author
+# name does not represent a full name as this operation is somewhat expensive. Under the
+# hood, it makes several calls to the Github and JIRA API servers to find the candidates.
+#
+# This returns a list of (candidate name, source) 2-tuples. E.g.
+# [
+# (NOT_FOUND, "No full name found for Github user andrewor14"),
+# ("Andrew Or", "Full name of JIRA user andrewor14"),
+# ("Andrew Orso", "Full name of SPARK-1444 assignee andrewor14"),
+# ("Andrew Ordall", "Full name of SPARK-1663 assignee andrewor14"),
+# (NOT_FOUND, "No assignee found for SPARK-1763")
+# ]
+NOT_FOUND = "Not found"
+def generate_candidates(author, issues):
+ candidates = []
+ # First check for full name of Github user
+ github_name = get_github_name(author, github_client)
+ if github_name:
+ candidates.append((github_name, "Full name of Github user %s" % author))
+ else:
+ candidates.append((NOT_FOUND, "No full name found for Github user %s" % author))
+ # Then do the same for JIRA user
+ jira_name = get_jira_name(author, jira_client)
+ if jira_name:
+ candidates.append((jira_name, "Full name of JIRA user %s" % author))
+ else:
+ candidates.append((NOT_FOUND, "No full name found for JIRA user %s" % author))
+ # Then do the same for the assignee of each of the associated JIRAs
+ # Note that a given issue may not have an assignee, or the assignee may not have a full name
+ for issue in issues:
+ try:
+ jira_issue = jira_client.issue(issue)
+ except JIRAError as e:
+ # Do not exit just because an issue is not found!
+ if e.status_code == 404:
+ warnings.append("Issue %s not found!" % issue)
+ continue
+ raise e
+ jira_assignee = jira_issue.fields.assignee
+ if jira_assignee:
+ user_name = jira_assignee.name
+ display_name = jira_assignee.displayName
+ if display_name:
+ candidates.append((display_name, "Full name of %s assignee %s" % (issue, user_name)))
+ else:
+                candidates.append((NOT_FOUND, "No full name found for %s assignee %s" % (issue, user_name)))
+ else:
+ candidates.append((NOT_FOUND, "No assignee found for %s" % issue))
+ # Guard against special characters in candidate names
+ # Note that the candidate name may already be in unicode (JIRA returns this)
+ for i, (candidate, source) in enumerate(candidates):
+ try:
+ candidate = unicode(candidate, "UTF-8")
+ except TypeError:
+ # already in unicode
+ pass
+ candidate = unidecode.unidecode(candidate).strip()
+ candidates[i] = (candidate, source)
+ return candidates
+
+# Translate each invalid author by searching for possible candidates from Github and JIRA
+# In interactive mode, this script presents the user with a list of choices and has the user
+# select from this list. Additionally, the user may choose to enter a custom name.
+# In non-interactive mode, this script picks the first valid author name from the candidates
+# If no such name exists, the original name is used (without the JIRA numbers).
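+# For example, a (hypothetical) line whose author has a cached mapping in known_translations:
+#    * andrewor14/SPARK-3425/SPARK-6672 -- Bug fixes in Core; improvements in Web UI
+# becomes
+#    * Andrew Or -- Bug fixes in Core; improvements in Web UI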
+print "\n========================== Translating contributor list =========================="
+lines = contributors_file.readlines()
+contributions = []
+for i, line in enumerate(lines):
+ temp_author = line.strip(" * ").split(" -- ")[0]
+ print "Processing author %s (%d/%d)" % (temp_author, i + 1, len(lines))
+ if not temp_author:
+ error_msg = " ERROR: Expected the following format \" * -- \"\n"
+ error_msg += " ERROR: Actual = %s" % line
+ print error_msg
+ warnings.append(error_msg)
+ contributions.append(line)
+ continue
+ author = temp_author.split("/")[0]
+ # Use the local copy of known translations where possible
+ if author in known_translations:
+ line = line.replace(temp_author, known_translations[author])
+ elif not is_valid_author(author):
+ new_author = author
+ issues = temp_author.split("/")[1:]
+ candidates = generate_candidates(author, issues)
+ # Print out potential replacement candidates along with the sources, e.g.
+ # [X] No full name found for Github user andrewor14
+ # [X] No assignee found for SPARK-1763
+ # [0] Andrew Or - Full name of JIRA user andrewor14
+ # [1] Andrew Orso - Full name of SPARK-1444 assignee andrewor14
+ # [2] Andrew Ordall - Full name of SPARK-1663 assignee andrewor14
+ # [3] andrewor14 - Raw Github username
+ # [4] Custom
+ candidate_names = []
+ bad_prompts = [] # Prompts that can't actually be selected; print these first.
+ good_prompts = [] # Prompts that contain valid choices
+ for candidate, source in candidates:
+ if candidate == NOT_FOUND:
+ bad_prompts.append(" [X] %s" % source)
+ else:
+ index = len(candidate_names)
+ candidate_names.append(candidate)
+ good_prompts.append(" [%d] %s - %s" % (index, candidate, source))
+ raw_index = len(candidate_names)
+ custom_index = len(candidate_names) + 1
+ for p in bad_prompts: print p
+ if bad_prompts: print " ---"
+ for p in good_prompts: print p
+ # In interactive mode, additionally provide "custom" option and await user response
+ if INTERACTIVE_MODE:
+ print " [%d] %s - Raw Github username" % (raw_index, author)
+ print " [%d] Custom" % custom_index
+ response = raw_input(" Your choice: ")
+ last_index = custom_index
+ while not response.isdigit() or int(response) > last_index:
+ response = raw_input(" Please enter an integer between 0 and %d: " % last_index)
+ response = int(response)
+ if response == custom_index:
+ new_author = raw_input(" Please type a custom name for this author: ")
+ elif response != raw_index:
+ new_author = candidate_names[response]
+ # In non-interactive mode, just pick the first candidate
+ else:
+ valid_candidate_names = [name for name, _ in candidates\
+ if is_valid_author(name) and name != NOT_FOUND]
+ if valid_candidate_names:
+ new_author = valid_candidate_names[0]
+ # Finally, capitalize the author and replace the original one with it
+ # If the final replacement is still invalid, log a warning
+ if is_valid_author(new_author):
+ new_author = capitalize_author(new_author)
+ else:
+ warnings.append("Unable to find a valid name %s for author %s" % (author, temp_author))
+ print " * Replacing %s with %s" % (author, new_author)
+ # If we are in interactive mode, prompt the user whether we want to remember this new mapping
+ if INTERACTIVE_MODE and\
+ author not in known_translations and\
+ yesOrNoPrompt(" Add mapping %s -> %s to known translations file?" % (author, new_author)):
+ known_translations_file.write("%s - %s\n" % (author, new_author))
+ known_translations_file.flush()
+        line = line.replace(temp_author, new_author)
+ contributions.append(line)
+print "==================================================================================\n"
+contributors_file.close()
+known_translations_file.close()
+
+# Sort the contributions before writing them to the new file.
+# Additionally, check if there are any duplicate author rows.
+# This could happen if the same user has both a valid full
+# name (e.g. Andrew Or) and an invalid one (andrewor14).
+# If so, warn the user about this at the end.
+contributions.sort()
+all_authors = set()
+new_contributors_file_name = contributors_file_name + ".final"
+new_contributors_file = open(new_contributors_file_name, "w")
+for line in contributions:
+ author = line.strip(" * ").split(" -- ")[0]
+ if author in all_authors:
+ warnings.append("Detected duplicate author name %s. Please merge these manually." % author)
+ all_authors.add(author)
+ new_contributors_file.write(line)
+new_contributors_file.close()
+
+print "Translated contributors list successfully written to %s!" % new_contributors_file_name
+
+# Log any warnings encountered in the process
+if warnings:
+ print "\n========== Warnings encountered while translating the contributor list ==========="
+ for w in warnings: print w
+ print "Please manually correct these in the final contributors list at %s." % new_contributors_file_name
+ print "==================================================================================\n"
+
diff --git a/dev/run-tests b/dev/run-tests
index 0e9eefa76a18b..9192cb7e169f3 100755
--- a/dev/run-tests
+++ b/dev/run-tests
@@ -139,27 +139,28 @@ echo "========================================================================="
CURRENT_BLOCK=$BLOCK_BUILD
{
- # We always build with Hive because the PySpark Spark SQL tests need it.
- BUILD_MVN_PROFILE_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-0.12.0"
-
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
- #+ (either resolution or compilation) prompts the user for input either q, r, etc
- #+ to quit or retry. This echo is there to make it not block.
+ # (either resolution or compilation) prompts the user for input either q, r, etc
+ # to quit or retry. This echo is there to make it not block.
# NOTE: Do not quote $BUILD_MVN_PROFILE_ARGS or else it will be interpreted as a
- #+ single argument!
+ # single argument!
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
- # First build with 0.12 to ensure patches do not break the hive 12 build
- echo "[info] Compile with hive 0.12"
+ # First build with Hive 0.12.0 to ensure patches do not break the Hive 0.12.0 build
+ HIVE_12_BUILD_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver -Phive-0.12.0"
+ echo "[info] Compile with Hive 0.12.0"
echo -e "q\n" \
- | sbt/sbt $BUILD_MVN_PROFILE_ARGS clean hive/compile hive-thriftserver/compile \
+ | sbt/sbt $HIVE_12_BUILD_ARGS clean hive/compile hive-thriftserver/compile \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
- # Then build with default version(0.13.1) because tests are based on this version
- echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS -Phive"
+ # Then build with default Hive version (0.13.1) because tests are based on this version
+ echo "[info] Compile with Hive 0.13.1"
+ rm -rf lib_managed
+ echo "[info] Building Spark with these arguments: $SBT_MAVEN_PROFILES_ARGS"\
+ " -Phive -Phive-thriftserver"
echo -e "q\n" \
- | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive package assembly/assembly \
+ | sbt/sbt $SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver package assembly/assembly \
| grep -v -e "info.*Resolving" -e "warn.*Merging" -e "info.*Including"
}
@@ -174,13 +175,13 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
# If the Spark SQL tests are enabled, run the tests with the Hive profiles enabled.
# This must be a single argument, as it is.
if [ -n "$_RUN_SQL_TESTS" ]; then
- SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive"
+ SBT_MAVEN_PROFILES_ARGS="$SBT_MAVEN_PROFILES_ARGS -Phive -Phive-thriftserver"
fi
if [ -n "$_SQL_TESTS_ONLY" ]; then
# This must be an array of individual arguments. Otherwise, having one long string
- #+ will be interpreted as a single test, which doesn't work.
- SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test")
+ # will be interpreted as a single test, which doesn't work.
+ SBT_MAVEN_TEST_ARGS=("catalyst/test" "sql/test" "hive/test" "mllib/test")
else
SBT_MAVEN_TEST_ARGS=("test")
fi
@@ -188,11 +189,11 @@ CURRENT_BLOCK=$BLOCK_SPARK_UNIT_TESTS
echo "[info] Running Spark tests with these arguments: $SBT_MAVEN_PROFILES_ARGS ${SBT_MAVEN_TEST_ARGS[@]}"
# NOTE: echo "q" is needed because sbt on encountering a build file with failure
- #+ (either resolution or compilation) prompts the user for input either q, r, etc
- #+ to quit or retry. This echo is there to make it not block.
+ # (either resolution or compilation) prompts the user for input either q, r, etc
+ # to quit or retry. This echo is there to make it not block.
# NOTE: Do not quote $SBT_MAVEN_PROFILES_ARGS or else it will be interpreted as a
- #+ single argument!
- #+ "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array.
+ # single argument!
+ # "${SBT_MAVEN_TEST_ARGS[@]}" is cool because it's an array.
# QUESTION: Why doesn't 'yes "q"' work?
# QUESTION: Why doesn't 'grep -v -e "^\[info\] Resolving"' work?
echo -e "q\n" \
@@ -211,7 +212,7 @@ CURRENT_BLOCK=$BLOCK_PYSPARK_UNIT_TESTS
echo ""
echo "========================================================================="
-echo "Detecting binary incompatibilites with MiMa"
+echo "Detecting binary incompatibilities with MiMa"
echo "========================================================================="
CURRENT_BLOCK=$BLOCK_MIMA
diff --git a/dev/scalastyle b/dev/scalastyle
index ed1b6b730af6e..c3c6012e74ffa 100755
--- a/dev/scalastyle
+++ b/dev/scalastyle
@@ -17,7 +17,7 @@
# limitations under the License.
#
-echo -e "q\n" | sbt/sbt -Phive scalastyle > scalastyle.txt
+echo -e "q\n" | sbt/sbt -Phive -Phive-thriftserver scalastyle > scalastyle.txt
# Check style with YARN alpha built too
echo -e "q\n" | sbt/sbt -Pyarn-alpha -Phadoop-0.23 -Dhadoop.version=0.23.9 yarn-alpha/scalastyle \
>> scalastyle.txt
diff --git a/docs/README.md b/docs/README.md
index d2d58e435d4c4..119484038083f 100644
--- a/docs/README.md
+++ b/docs/README.md
@@ -43,7 +43,7 @@ You can modify the default Jekyll build as follows:
## Pygments
We also use pygments (http://pygments.org) for syntax highlighting in documentation markdown pages,
-so you will also need to install that (it requires Python) by running `sudo easy_install Pygments`.
+so you will also need to install that (it requires Python) by running `sudo pip install Pygments`.
To mark a block of code in your markdown to be syntax highlighted by jekyll during the compile
phase, use the following syntax:
@@ -53,6 +53,11 @@ phase, use the following sytax:
// supported languages too.
{% endhighlight %}
+## Sphinx
+
+We use Sphinx to generate Python API docs, so you will need to install it by running
+`sudo pip install sphinx`.
+
## API Docs (Scaladoc and Sphinx)
You can build just the Spark scaladoc by running `sbt/sbt doc` from the SPARK_PROJECT_ROOT directory.
diff --git a/docs/_config.yml b/docs/_config.yml
index cdea02fcffbc5..a6c176cde5a49 100644
--- a/docs/_config.yml
+++ b/docs/_config.yml
@@ -13,8 +13,8 @@ include:
# These allow the documentation to be updated with newer releases
# of Spark, Scala, and Mesos.
-SPARK_VERSION: 1.2.0-SNAPSHOT
-SPARK_VERSION_SHORT: 1.2.0
+SPARK_VERSION: 1.2.1
+SPARK_VERSION_SHORT: 1.2.1
SCALA_BINARY_VERSION: "2.10"
SCALA_VERSION: "2.10.4"
MESOS_VERSION: 0.18.1
diff --git a/docs/_layouts/global.html b/docs/_layouts/global.html
index 627ed37de4a9c..8841f7675d35e 100755
--- a/docs/_layouts/global.html
+++ b/docs/_layouts/global.html
@@ -33,7 +33,7 @@
[The markup in the hunks below was lost during extraction. The docs/_layouts/global.html hunk
is unrecoverable. The subsequent hunks (around lines 1174, 1334, and 1355 of what appears to be
the root pom.xml) can be partially recovered:
 - a new doclint-java8-disable profile, activated on JDK [1.8,), configures
   maven-javadoc-plugin with "-Xdoclint:all -Xdoclint:-missing";
 - the hadoop-provided profile loses its activeByDefault=false activation (and some
   org.apache.hadoop entries are removed);
 - the hive profile is replaced by a hive-thriftserver profile containing the
   sql/hive-thriftserver module;
 - the hive-0.12.0 and hive-0.13.1 profiles drop their activeByDefault activation and set
   Hive versions 0.12.0-protobuf-2.5 / 0.12.0 and 0.13.1a / 0.13.1 (Derby 10.10.1.1);
 - new scala-2.10 and scala-2.11 profiles set Scala 2.10.4 (jline from org.scala-lang,
   external/kafka module) and Scala 2.11.2 (jline 2.12), respectively.]
diff --git a/project/MimaExcludes.scala b/project/MimaExcludes.scala
index 6a0495f8fd540..8a2a865867fc4 100644
--- a/project/MimaExcludes.scala
+++ b/project/MimaExcludes.scala
@@ -77,6 +77,18 @@ object MimaExcludes {
// SPARK-3822
ProblemFilters.exclude[IncompatibleResultTypeProblem](
"org.apache.spark.SparkContext.org$apache$spark$SparkContext$$createTaskScheduler")
+ ) ++ Seq(
+ // SPARK-1209
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.hadoop.mapreduce.SparkHadoopMapReduceUtil"),
+ ProblemFilters.exclude[MissingClassProblem](
+ "org.apache.hadoop.mapred.SparkHadoopMapRedUtil"),
+ ProblemFilters.exclude[MissingTypesProblem](
+ "org.apache.spark.rdd.PairRDDFunctions")
+ ) ++ Seq(
+ // SPARK-4062
+ ProblemFilters.exclude[MissingMethodProblem](
+ "org.apache.spark.streaming.kafka.KafkaReceiver#MessageHandler.this")
)
case v if v.startsWith("1.1") =>
diff --git a/project/SparkBuild.scala b/project/SparkBuild.scala
index 33618f5401768..49628b1f51e4c 100644
--- a/project/SparkBuild.scala
+++ b/project/SparkBuild.scala
@@ -15,6 +15,8 @@
* limitations under the License.
*/
+import java.io.File
+
import scala.util.Properties
import scala.collection.JavaConversions._
@@ -23,7 +25,7 @@ import sbt.Classpaths.publishTask
import sbt.Keys._
import sbtunidoc.Plugin.genjavadocSettings
import sbtunidoc.Plugin.UnidocKeys.unidocGenjavadocVersion
-import com.typesafe.sbt.pom.{PomBuild, SbtPomKeys}
+import com.typesafe.sbt.pom.{loadEffectivePom, PomBuild, SbtPomKeys}
import net.virtualvoid.sbt.graph.Plugin.graphSettings
object BuildCommons {
@@ -31,19 +33,19 @@ object BuildCommons {
private val buildLocation = file(".").getAbsoluteFile.getParentFile
val allProjects@Seq(bagel, catalyst, core, graphx, hive, hiveThriftServer, mllib, repl,
- sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
- streamingMqtt, streamingTwitter, streamingZeromq) =
+ sql, networkCommon, networkShuffle, streaming, streamingFlumeSink, streamingFlume, streamingKafka,
+ streamingMqtt, streamingTwitter, streamingZeromq) =
Seq("bagel", "catalyst", "core", "graphx", "hive", "hive-thriftserver", "mllib", "repl",
"sql", "network-common", "network-shuffle", "streaming", "streaming-flume-sink",
"streaming-flume", "streaming-kafka", "streaming-mqtt", "streaming-twitter",
"streaming-zeromq").map(ProjectRef(buildLocation, _))
- val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests, sparkGangliaLgpl, sparkKinesisAsl) =
- Seq("yarn", "yarn-stable", "yarn-alpha", "java8-tests", "ganglia-lgpl", "kinesis-asl")
- .map(ProjectRef(buildLocation, _))
+ val optionallyEnabledProjects@Seq(yarn, yarnStable, yarnAlpha, java8Tests,
+ sparkGangliaLgpl, sparkKinesisAsl) = Seq("yarn", "yarn-stable", "yarn-alpha",
+ "java8-tests", "ganglia-lgpl", "kinesis-asl").map(ProjectRef(buildLocation, _))
- val assemblyProjects@Seq(assembly, examples) = Seq("assembly", "examples")
- .map(ProjectRef(buildLocation, _))
+ val assemblyProjects@Seq(assembly, examples, networkYarn) =
+ Seq("assembly", "examples", "network-yarn").map(ProjectRef(buildLocation, _))
val tools = ProjectRef(buildLocation, "tools")
// Root project.
@@ -68,8 +70,8 @@ object SparkBuild extends PomBuild {
profiles ++= Seq("spark-ganglia-lgpl")
}
if (Properties.envOrNone("SPARK_HIVE").isDefined) {
- println("NOTE: SPARK_HIVE is deprecated, please use -Phive flag.")
- profiles ++= Seq("hive")
+ println("NOTE: SPARK_HIVE is deprecated, please use -Phive and -Phive-thriftserver flags.")
+ profiles ++= Seq("hive", "hive-thriftserver")
}
Properties.envOrNone("SPARK_HADOOP_VERSION") match {
case Some(v) =>
@@ -91,13 +93,23 @@ object SparkBuild extends PomBuild {
profiles
}
- override val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
+ override val profiles = {
+ val profiles = Properties.envOrNone("SBT_MAVEN_PROFILES") match {
case None => backwardCompatibility
case Some(v) =>
if (backwardCompatibility.nonEmpty)
println("Note: We ignore environment variables, when use of profile is detected in " +
"conjunction with environment variable.")
v.split("(\\s+|,)").filterNot(_.isEmpty).map(_.trim.replaceAll("-P", "")).toSeq
+ }
+
+ if (System.getProperty("scala-2.11") == "") {
+        // To activate the scala-2.11 profile, replace the empty property value with a non-empty
+        // one, mirroring Maven, which rewrites -Dname to -Dname=true before executing the build.
+ // see: https://github.com/apache/maven/blob/maven-3.0.4/maven-embedder/src/main/java/org/apache/maven/cli/MavenCli.java#L1082
+ System.setProperty("scala-2.11", "true")
+ }
+ profiles
}
Properties.envOrNone("SBT_MAVEN_PROPERTIES") match {
@@ -108,6 +120,17 @@ object SparkBuild extends PomBuild {
override val userPropertiesMap = System.getProperties.toMap
+ // Handle case where hadoop.version is set via profile.
+ // Needed only because we read back this property in sbt
+ // when we create the assembly jar.
+ val pom = loadEffectivePom(new File("pom.xml"),
+ profiles = profiles,
+ userProps = userPropertiesMap)
+ if (System.getProperty("hadoop.version") == null) {
+ System.setProperty("hadoop.version",
+ pom.getProperties.get("hadoop.version").asInstanceOf[String])
+ }
+
lazy val MavenCompile = config("m2r") extend(Compile)
lazy val publishLocalBoth = TaskKey[Unit]("publish-local", "publish local for m2 and ivy")
@@ -126,7 +149,12 @@ object SparkBuild extends PomBuild {
},
publishMavenStyle in MavenCompile := true,
publishLocal in MavenCompile <<= publishTask(publishLocalConfiguration in MavenCompile, deliverLocal),
- publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn
+ publishLocalBoth <<= Seq(publishLocal in MavenCompile, publishLocal).dependOn,
+
+ javacOptions in (Compile, doc) ++= {
+ val Array(major, minor, _) = System.getProperty("java.version").split("\\.", 3)
+ if (major.toInt >= 1 && minor.toInt >= 8) Seq("-Xdoclint:all", "-Xdoclint:-missing") else Seq.empty
+ }
)
def enable(settings: Seq[Setting[_]])(projectRef: ProjectRef) = {
@@ -136,14 +164,15 @@ object SparkBuild extends PomBuild {
// Note ordering of these settings matter.
/* Enable shared settings on all projects */
- (allProjects ++ optionallyEnabledProjects ++ assemblyProjects).foreach(enable(sharedSettings))
+ (allProjects ++ optionallyEnabledProjects ++ assemblyProjects ++ Seq(spark, tools))
+ .foreach(enable(sharedSettings ++ ExludedDependencies.settings))
/* Enable tests settings for all projects except examples, assembly and tools */
(allProjects ++ optionallyEnabledProjects).foreach(enable(TestSettings.settings))
// TODO: Add Sql to mima checks
allProjects.filterNot(x => Seq(spark, sql, hive, hiveThriftServer, catalyst, repl,
- streamingFlumeSink, networkCommon, networkShuffle).contains(x)).foreach {
+ streamingFlumeSink, networkCommon, networkShuffle, networkYarn).contains(x)).foreach {
x => enable(MimaBuild.mimaSettings(sparkHome, x))(x)
}
@@ -178,6 +207,16 @@ object Flume {
lazy val settings = sbtavro.SbtAvro.avroSettings
}
+/**
+  Excludes library dependencies that are declared in the Maven build but are
+  not needed by the sbt build.
+ */
+object ExludedDependencies {
+ lazy val settings = Seq(
+ libraryDependencies ~= { libs => libs.filterNot(_.name == "groovy-all") }
+ )
+}
+
/**
* Following project only exists to pull previous artifacts of Spark for generating
* Mima ignores. For more information see: SPARK 2071
@@ -234,6 +273,8 @@ object Hive {
lazy val settings = Seq(
javaOptions += "-XX:MaxPermSize=1g",
+    // Explicitly disable assertions since some Hive tests fail them
+ javaOptions in Test := (javaOptions in Test).value.filterNot(_ == "-ea"),
// Multiple queries rely on the TestHive singleton. See comments there for more details.
parallelExecution in Test := false,
// Supporting all SerDes requires us to depend on deprecated APIs, so we turn off the warnings
@@ -270,8 +311,14 @@ object Assembly {
lazy val settings = assemblySettings ++ Seq(
test in assembly := {},
- jarName in assembly <<= (version, moduleName) map { (v, mName) => mName + "-"+v + "-hadoop" +
- Option(System.getProperty("hadoop.version")).getOrElse("1.0.4") + ".jar" },
+ jarName in assembly <<= (version, moduleName) map { (v, mName) =>
+ if (mName.contains("network-yarn")) {
+ // This must match the same name used in maven (see network/yarn/pom.xml)
+ "spark-" + v + "-yarn-shuffle.jar"
+ } else {
+ mName + "-" + v + "-hadoop" + System.getProperty("hadoop.version") + ".jar"
+ }
+ },
mergeStrategy in assembly := {
case PathList("org", "datanucleus", xs @ _*) => MergeStrategy.discard
case m if m.toLowerCase.endsWith("manifest.mf") => MergeStrategy.discard
@@ -302,7 +349,7 @@ object Unidoc {
unidocProjectFilter in(ScalaUnidoc, unidoc) :=
inAnyProject -- inProjects(OldDeps.project, repl, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha),
unidocProjectFilter in(JavaUnidoc, unidoc) :=
- inAnyProject -- inProjects(OldDeps.project, repl, bagel, graphx, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha),
+ inAnyProject -- inProjects(OldDeps.project, repl, bagel, examples, tools, catalyst, streamingFlumeSink, yarn, yarnAlpha),
// Skip class names containing $ and some internal packages in Javadocs
unidocAllSources in (JavaUnidoc, unidoc) := {
@@ -330,7 +377,10 @@ object Unidoc {
"mllib.classification", "mllib.clustering", "mllib.evaluation.binary", "mllib.linalg",
"mllib.linalg.distributed", "mllib.optimization", "mllib.rdd", "mllib.recommendation",
"mllib.regression", "mllib.stat", "mllib.tree", "mllib.tree.configuration",
- "mllib.tree.impurity", "mllib.tree.model", "mllib.util"
+ "mllib.tree.impurity", "mllib.tree.model", "mllib.util",
+ "mllib.evaluation", "mllib.feature", "mllib.random", "mllib.stat.correlation",
+ "mllib.stat.test", "mllib.tree.impl", "mllib.tree.loss",
+ "ml", "ml.classification", "ml.evaluation", "ml.feature", "ml.param", "ml.tuning"
),
"-group", "Spark SQL", packageList("sql.api.java", "sql.api.java.types", "sql.hive.api.java"),
"-noqualifier", "java.lang"
@@ -348,13 +398,19 @@ object TestSettings {
javaOptions in Test += "-Dspark.testing=1",
javaOptions in Test += "-Dspark.port.maxRetries=100",
javaOptions in Test += "-Dspark.ui.enabled=false",
+ javaOptions in Test += "-Dspark.ui.showConsoleProgress=false",
+ javaOptions in Test += "-Dspark.driver.allowMultipleContexts=true",
javaOptions in Test += "-Dsun.io.serialization.extendedDebugInfo=true",
javaOptions in Test ++= System.getProperties.filter(_._1 startsWith "spark")
.map { case (k,v) => s"-D$k=$v" }.toSeq,
+ javaOptions in Test += "-ea",
javaOptions in Test ++= "-Xmx3g -XX:PermSize=128M -XX:MaxNewSize=256m -XX:MaxPermSize=1g"
.split(" ").toSeq,
+ // This places test scope jars on the classpath of executors during tests.
+ javaOptions in Test +=
+ "-Dspark.executor.extraClassPath=" + (fullClasspath in Test).value.files.
+ map(_.getAbsolutePath).mkString(":").stripSuffix(":"),
javaOptions += "-Xmx3g",
-
// Show full stack trace and duration in test cases.
testOptions in Test += Tests.Argument("-oDF"),
testOptions += Tests.Argument(TestFrameworks.JUnit, "-v", "-a"),
diff --git a/project/project/SparkPluginBuild.scala b/project/project/SparkPluginBuild.scala
index 3ef2d5451da0d..8863f272da415 100644
--- a/project/project/SparkPluginBuild.scala
+++ b/project/project/SparkPluginBuild.scala
@@ -26,7 +26,7 @@ import sbt.Keys._
object SparkPluginDef extends Build {
lazy val root = Project("plugins", file(".")) dependsOn(sparkStyle, sbtPomReader)
lazy val sparkStyle = Project("spark-style", file("spark-style"), settings = styleSettings)
- lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git")
+ lazy val sbtPomReader = uri("https://github.com/ScrapCodes/sbt-pom-reader.git#ignore_artifact_id")
// There is actually no need to publish this artifact.
def styleSettings = Defaults.defaultSettings ++ Seq (
diff --git a/python/docs/epytext.py b/python/docs/epytext.py
index 19fefbfc057a4..e884d5e6b19c7 100644
--- a/python/docs/epytext.py
+++ b/python/docs/epytext.py
@@ -1,7 +1,7 @@
import re
RULES = (
- (r"<[\w.]+>", r""),
+ (r"<(!BLANKLINE)[\w.]+>", r""),
(r"L{([\w.()]+)}", r":class:`\1`"),
(r"[LC]{(\w+\.\w+)\(\)}", r":func:`\1`"),
(r"C{([\w.()]+)}", r":class:`\1`"),
diff --git a/python/docs/pyspark.streaming.rst b/python/docs/pyspark.streaming.rst
index 5024d694b668f..f08185627d0bc 100644
--- a/python/docs/pyspark.streaming.rst
+++ b/python/docs/pyspark.streaming.rst
@@ -1,5 +1,5 @@
pyspark.streaming module
-==================
+========================
Module contents
---------------
diff --git a/python/pyspark/__init__.py b/python/pyspark/__init__.py
index e39e6514d77a1..9556e4718e585 100644
--- a/python/pyspark/__init__.py
+++ b/python/pyspark/__init__.py
@@ -37,16 +37,6 @@
"""
-# The following block allows us to import python's random instead of mllib.random for scripts in
-# mllib that depend on top level pyspark packages, which transitively depend on python's random.
-# Since Python's import logic looks for modules in the current package first, we eliminate
-# mllib.random as a candidate for C{import random} by removing the first search path, the script's
-# location, in order to force the loader to look in Python's top-level modules for C{random}.
-import sys
-s = sys.path.pop(0)
-import random
-sys.path.insert(0, s)
-
from pyspark.conf import SparkConf
from pyspark.context import SparkContext
from pyspark.rdd import RDD
diff --git a/python/pyspark/broadcast.py b/python/pyspark/broadcast.py
index f124dc6c07575..6b8a8b256a891 100644
--- a/python/pyspark/broadcast.py
+++ b/python/pyspark/broadcast.py
@@ -15,21 +15,10 @@
# limitations under the License.
#
-"""
->>> from pyspark.context import SparkContext
->>> sc = SparkContext('local', 'test')
->>> b = sc.broadcast([1, 2, 3, 4, 5])
->>> b.value
-[1, 2, 3, 4, 5]
->>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
-[1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
->>> b.unpersist()
-
->>> large_broadcast = sc.broadcast(list(range(10000)))
-"""
import os
-
-from pyspark.serializers import CompressedSerializer, PickleSerializer
+import cPickle
+import gc
+from tempfile import NamedTemporaryFile
__all__ = ['Broadcast']
@@ -49,44 +38,88 @@ def _from_id(bid):
class Broadcast(object):
"""
- A broadcast variable created with
- L{SparkContext.broadcast()}.
+ A broadcast variable created with L{SparkContext.broadcast()}.
Access its value through C{.value}.
+
+ Examples:
+
+ >>> from pyspark.context import SparkContext
+ >>> sc = SparkContext('local', 'test')
+ >>> b = sc.broadcast([1, 2, 3, 4, 5])
+ >>> b.value
+ [1, 2, 3, 4, 5]
+ >>> sc.parallelize([0, 0]).flatMap(lambda x: b.value).collect()
+ [1, 2, 3, 4, 5, 1, 2, 3, 4, 5]
+ >>> b.unpersist()
+
+ >>> large_broadcast = sc.broadcast(range(10000))
"""
- def __init__(self, bid, value, java_broadcast=None,
- pickle_registry=None, path=None):
+ def __init__(self, sc=None, value=None, pickle_registry=None, path=None):
"""
- Should not be called directly by users -- use
- L{SparkContext.broadcast()}
+ Should not be called directly by users -- use L{SparkContext.broadcast()}
instead.
"""
- self.bid = bid
- if path is None:
- self._value = value
- self._jbroadcast = java_broadcast
- self._pickle_registry = pickle_registry
- self.path = path
+ if sc is not None:
+ f = NamedTemporaryFile(delete=False, dir=sc._temp_dir)
+ self._path = self.dump(value, f)
+ self._jbroadcast = sc._jvm.PythonRDD.readBroadcastFromFile(sc._jsc, self._path)
+ self._pickle_registry = pickle_registry
+ else:
+ self._jbroadcast = None
+ self._path = path
+
+ def dump(self, value, f):
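+        # On-disk format (as written below): a one-byte type flag followed by the payload:
+        #   'U' + UTF-8 encoded bytes  for unicode strings
+        #   'S' + raw bytes            for byte strings
+        #   'P' + pickled bytes        for any other object (pickle protocol 2)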
+ if isinstance(value, basestring):
+ if isinstance(value, unicode):
+ f.write('U')
+ value = value.encode('utf8')
+ else:
+ f.write('S')
+ f.write(value)
+ else:
+ f.write('P')
+ cPickle.dump(value, f, 2)
+ f.close()
+ return f.name
+
+ def load(self, path):
+ with open(path, 'rb', 1 << 20) as f:
+ flag = f.read(1)
+ data = f.read()
+ if flag == 'P':
+            # cPickle.loads() may create lots of objects; disable GC
+            # temporarily for better performance
+ gc.disable()
+ try:
+ return cPickle.loads(data)
+ finally:
+ gc.enable()
+ else:
+ return data.decode('utf8') if flag == 'U' else data
@property
def value(self):
""" Return the broadcasted value
"""
- if not hasattr(self, "_value") and self.path is not None:
- ser = CompressedSerializer(PickleSerializer())
- self._value = ser.load_stream(open(self.path)).next()
+ if not hasattr(self, "_value") and self._path is not None:
+ self._value = self.load(self._path)
return self._value
def unpersist(self, blocking=False):
"""
Delete cached copies of this broadcast on the executors.
"""
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be unpersisted in driver")
self._jbroadcast.unpersist(blocking)
- os.unlink(self.path)
+ os.unlink(self._path)
def __reduce__(self):
+ if self._jbroadcast is None:
+ raise Exception("Broadcast can only be serialized in driver")
self._pickle_registry.add(self)
- return (_from_id, (self.bid, ))
+ return _from_id, (self._jbroadcast.id(),)
if __name__ == "__main__":
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index 5f8dcedb1eea2..23ff8ccf61035 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -29,7 +29,7 @@
from pyspark.files import SparkFiles
from pyspark.java_gateway import launch_gateway
from pyspark.serializers import PickleSerializer, BatchedSerializer, UTF8Deserializer, \
- PairDeserializer, CompressedSerializer, AutoBatchedSerializer
+ PairDeserializer, AutoBatchedSerializer, NoOpSerializer
from pyspark.storagelevel import StorageLevel
from pyspark.rdd import RDD
from pyspark.traceback_utils import CallSite, first_spark_call
@@ -63,7 +63,6 @@ class SparkContext(object):
_active_spark_context = None
_lock = Lock()
_python_includes = None # zip and egg files that need to be added to PYTHONPATH
- _default_batch_size_for_serialized_input = 10
def __init__(self, master=None, appName=None, sparkHome=None, pyFiles=None,
environment=None, batchSize=0, serializer=PickleSerializer(), conf=None,
@@ -115,9 +114,7 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
self._conf = conf or SparkConf(_jvm=self._jvm)
self._batchSize = batchSize # -1 represents an unlimited batch size
self._unbatched_serializer = serializer
- if batchSize == 1:
- self.serializer = self._unbatched_serializer
- elif batchSize == 0:
+ if batchSize == 0:
self.serializer = AutoBatchedSerializer(self._unbatched_serializer)
else:
self.serializer = BatchedSerializer(self._unbatched_serializer,
@@ -192,7 +189,8 @@ def _do_init(self, master, appName, sparkHome, pyFiles, environment, batchSize,
# Create a temporary directory inside spark.local.dir:
local_dir = self._jvm.org.apache.spark.util.Utils.getLocalDir(self._jsc.sc().conf())
self._temp_dir = \
- self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir).getAbsolutePath()
+ self._jvm.org.apache.spark.util.Utils.createTempDir(local_dir, "pyspark") \
+ .getAbsolutePath()
# profiling stats collected for each PythonRDD
self._profile_stats = []
@@ -232,6 +230,14 @@ def _ensure_initialized(cls, instance=None, gateway=None):
else:
SparkContext._active_spark_context = instance
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle SparkContext, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to reference SparkContext from a broadcast "
+ "variable, action, or transforamtion. SparkContext can only be used on the driver, "
+ "not in code that it run on workers. For more information, see SPARK-5063."
+ )
+
def __enter__(self):
"""
Enable 'with SparkContext(...) as sc: app(sc)' syntax.
@@ -292,12 +298,29 @@ def stop(self):
def parallelize(self, c, numSlices=None):
"""
- Distribute a local Python collection to form an RDD.
+        Distribute a local Python collection to form an RDD. For performance,
+        using xrange is recommended if the input represents a range.
- >>> sc.parallelize(range(5), 5).glom().collect()
- [[0], [1], [2], [3], [4]]
+ >>> sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect()
+ [[0], [2], [3], [4], [6]]
+ >>> sc.parallelize(xrange(0, 6, 2), 5).glom().collect()
+ [[], [0], [], [2], [4]]
"""
- numSlices = numSlices or self.defaultParallelism
+ numSlices = int(numSlices) if numSlices is not None else self.defaultParallelism
+ if isinstance(c, xrange):
+ size = len(c)
+ if size == 0:
+ return self.parallelize([], numSlices)
+ step = c[1] - c[0] if size > 1 else 1
+ start0 = c[0]
+
+ def getStart(split):
+ return start0 + (split * size / numSlices) * step
+
+ def f(split, iterator):
+ return xrange(getStart(split), getStart(split + 1), step)
+
+ return self.parallelize([], numSlices).mapPartitionsWithIndex(f)
# Calling the Java parallelize() method with an ArrayList is too slow,
# because it sends O(n) Py4J commands. As an alternative, serialized
# objects are written to a file and loaded through textFile().
@@ -305,12 +328,8 @@ def parallelize(self, c, numSlices=None):
# Make sure we distribute data evenly if it's smaller than self.batchSize
if "__len__" not in dir(c):
c = list(c) # Make it a list so we can compute its length
- batchSize = min(len(c) // numSlices, self._batchSize)
- if batchSize > 1:
- serializer = BatchedSerializer(self._unbatched_serializer,
- batchSize)
- else:
- serializer = self._unbatched_serializer
+ batchSize = max(1, min(len(c) // numSlices, self._batchSize or 1024))
+ serializer = BatchedSerializer(self._unbatched_serializer, batchSize)
serializer.dump_stream(c, tempFile)
tempFile.close()
readRDDFromFile = self._jvm.PythonRDD.readRDDFromFile
@@ -328,8 +347,7 @@ def pickleFile(self, name, minPartitions=None):
[0, 1, 2, 3, 4, 5, 6, 7, 8, 9]
"""
minPartitions = minPartitions or self.defaultMinPartitions
- return RDD(self._jsc.objectFile(name, minPartitions), self,
- BatchedSerializer(PickleSerializer()))
+ return RDD(self._jsc.objectFile(name, minPartitions), self)
def textFile(self, name, minPartitions=None, use_unicode=True):
"""
@@ -396,6 +414,36 @@ def wholeTextFiles(self, path, minPartitions=None, use_unicode=True):
return RDD(self._jsc.wholeTextFiles(path, minPartitions), self,
PairDeserializer(UTF8Deserializer(use_unicode), UTF8Deserializer(use_unicode)))
+ def binaryFiles(self, path, minPartitions=None):
+ """
+ :: Experimental ::
+
+ Read a directory of binary files from HDFS, a local file system
+ (available on all nodes), or any Hadoop-supported file system URI
+ as a byte array. Each file is read as a single record and returned
+ in a key-value pair, where the key is the path of each file and the
+ value is the content of each file.
+
+ Note: Small files are preferred; large files are also allowed, but
+ may cause poor performance.
+ """
+ minPartitions = minPartitions or self.defaultMinPartitions
+ return RDD(self._jsc.binaryFiles(path, minPartitions), self,
+ PairDeserializer(UTF8Deserializer(), NoOpSerializer()))
+
+ def binaryRecords(self, path, recordLength):
+ """
+ :: Experimental ::
+
+ Load data from a flat binary file, assuming each record is a set of numbers
+ with the specified numerical format (see ByteBuffer), and the number of
+ bytes per record is constant.
+
+ :param path: Directory containing the input data files
+ :param recordLength: The length at which to split the records
+ """
+ return RDD(self._jsc.binaryRecords(path, recordLength), self, NoOpSerializer())
+
def _dictToJavaMap(self, d):
jm = self._jvm.java.util.HashMap()
if not d:
@@ -405,7 +453,7 @@ def _dictToJavaMap(self, d):
return jm
def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
- valueConverter=None, minSplits=None, batchSize=None):
+ valueConverter=None, minSplits=None, batchSize=0):
"""
Read a Hadoop SequenceFile with arbitrary key and value Writable class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -427,17 +475,15 @@ def sequenceFile(self, path, keyClass=None, valueClass=None, keyConverter=None,
:param minSplits: minimum splits in dataset
(default min(2, sc.defaultParallelism))
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+ Java object. (default 0, choose batchSize automatically)
"""
minSplits = minSplits or min(self.defaultParallelism, 2)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.sequenceFile(self._jsc, path, keyClass, valueClass,
keyConverter, valueConverter, minSplits, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read a 'new API' Hadoop InputFormat with arbitrary key and value class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -458,18 +504,16 @@ def newAPIHadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConv
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+ Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopFile(self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read a 'new API' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
Hadoop configuration, which is passed in as a Python dict.
@@ -487,18 +531,16 @@ def newAPIHadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=N
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+ Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.newAPIHadoopRDD(self._jsc, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read an 'old' Hadoop InputFormat with arbitrary key and value class from HDFS,
a local file system (available on all nodes), or any Hadoop-supported file system URI.
@@ -519,18 +561,16 @@ def hadoopFile(self, path, inputFormatClass, keyClass, valueClass, keyConverter=
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+ Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopFile(self._jsc, path, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
- valueConverter=None, conf=None, batchSize=None):
+ valueConverter=None, conf=None, batchSize=0):
"""
Read an 'old' Hadoop InputFormat with arbitrary key and value class, from an arbitrary
Hadoop configuration, which is passed in as a Python dict.
@@ -548,15 +588,13 @@ def hadoopRDD(self, inputFormatClass, keyClass, valueClass, keyConverter=None,
:param conf: Hadoop configuration, passed in as a dict
(None by default)
:param batchSize: The number of Python objects represented as a single
- Java object. (default sc._default_batch_size_for_serialized_input)
+ Java object. (default 0, choose batchSize automatically)
"""
jconf = self._dictToJavaMap(conf)
- batchSize = max(1, batchSize or self._default_batch_size_for_serialized_input)
- ser = BatchedSerializer(PickleSerializer()) if (batchSize > 1) else PickleSerializer()
jrdd = self._jvm.PythonRDD.hadoopRDD(self._jsc, inputFormatClass, keyClass,
valueClass, keyConverter, valueConverter,
jconf, batchSize)
- return RDD(jrdd, self, ser)
+ return RDD(jrdd, self)
def _checkpointFile(self, name, input_deserializer):
jrdd = self._jsc.checkpointFile(name)
@@ -595,14 +633,7 @@ def broadcast(self, value):
object for reading it in distributed functions. The variable will
be sent to each cluster only once.
"""
- ser = CompressedSerializer(PickleSerializer())
- # pass large object by py4j is very slow and need much memory
- tempFile = NamedTemporaryFile(delete=False, dir=self._temp_dir)
- ser.dump_stream([value], tempFile)
- tempFile.close()
- jbroadcast = self._jvm.PythonRDD.readBroadcastFromFile(self._jsc, tempFile.name)
- return Broadcast(jbroadcast.id(), None, jbroadcast,
- self._pickled_broadcast_vars, tempFile.name)
+ return Broadcast(self, value, self._pickled_broadcast_vars)
def accumulator(self, value, accum_param=None):
"""
@@ -836,7 +867,7 @@ def _test():
import doctest
import tempfile
globs = globals().copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
globs['tempdir'] = tempfile.mkdtemp()
atexit.register(lambda: shutil.rmtree(globs['tempdir']))
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
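
A minimal driver-side sketch (not part of the patch) of the updated SparkContext.parallelize() behavior above, assuming Python 2 and a local Spark build with this change applied; the master, app name, and data values are illustrative only.

from pyspark import SparkContext

sc = SparkContext("local[2]", "ParallelizeSketch")
# xrange input is sliced per partition lazily instead of being materialized on the driver:
print(sc.parallelize(xrange(0, 6, 2), 5).glom().collect())   # [[], [0], [], [2], [4]]
# other collections still go through the (now always batched) pickle serializer:
print(sc.parallelize([0, 2, 3, 4, 6], 5).glom().collect())   # [[0], [2], [3], [4], [6]]
sc.stop()
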
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index 9c70fa5c16d0c..a975dc19cb78e 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -45,7 +45,9 @@ def launch_gateway():
# Don't send ctrl-c / SIGINT to the Java gateway:
def preexec_func():
signal.signal(signal.SIGINT, signal.SIG_IGN)
- proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func)
+ env = dict(os.environ)
+ env["IS_SUBPROCESS"] = "1" # tell JVM to exit after python exits
+ proc = Popen(command, stdout=PIPE, stdin=PIPE, preexec_fn=preexec_func, env=env)
else:
# preexec_fn not supported on Windows
proc = Popen(command, stdout=PIPE, stdin=PIPE)
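
A stripped-down sketch (not part of the patch) of the Popen pattern used above: copy the parent environment, add one flag, and hand it to the child process. The `java -version` command is only a placeholder.

import os
from subprocess import Popen, PIPE

env = dict(os.environ)        # copy, so the parent's environment is not mutated
env["IS_SUBPROCESS"] = "1"    # flag the child process can check, as the gateway JVM does above
proc = Popen(["java", "-version"], stdout=PIPE, stderr=PIPE, env=env)
out, err = proc.communicate()
print(err)                    # `java -version` writes its output to stderr
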
diff --git a/python/pyspark/mllib/__init__.py b/python/pyspark/mllib/__init__.py
index 4149f54931d1f..c3217620e3c4e 100644
--- a/python/pyspark/mllib/__init__.py
+++ b/python/pyspark/mllib/__init__.py
@@ -24,3 +24,12 @@
import numpy
if numpy.version.version < '1.4':
raise Exception("MLlib requires NumPy 1.4+")
+
+__all__ = ['classification', 'clustering', 'feature', 'linalg', 'random',
+ 'recommendation', 'regression', 'stat', 'tree', 'util']
+
+import sys
+import rand as random
+random.__name__ = 'random'
+random.RandomRDDs.__module__ = __name__ + '.random'
+sys.modules[__name__ + '.random'] = random
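
The block above publishes rand.py under the name pyspark.mllib.random by registering the imported module in sys.modules. A tiny, self-contained illustration of that aliasing trick follows (it uses the standard random module and a made-up top-level name; the alias name is hypothetical).

import sys
import random as _impl

sys.modules['rand_alias'] = _impl   # register the module under an extra (hypothetical) name
import rand_alias                   # resolved straight from sys.modules, no file needed
assert rand_alias is _impl
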
diff --git a/python/pyspark/mllib/classification.py b/python/pyspark/mllib/classification.py
index 297a2bf37d2cf..f14d0ed11cbbb 100644
--- a/python/pyspark/mllib/classification.py
+++ b/python/pyspark/mllib/classification.py
@@ -20,96 +20,200 @@
import numpy
from numpy import array
+from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
from pyspark.mllib.regression import LabeledPoint, LinearModel, _regression_train_wrapper
-__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'SVMModel',
- 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
+__all__ = ['LogisticRegressionModel', 'LogisticRegressionWithSGD', 'LogisticRegressionWithLBFGS',
+ 'SVMModel', 'SVMWithSGD', 'NaiveBayesModel', 'NaiveBayes']
-class LogisticRegressionModel(LinearModel):
+class LinearBinaryClassificationModel(LinearModel):
+ """
+ Represents a linear binary classification model that predicts whether an
+ example is positive (1.0) or negative (0.0).
+ """
+ def __init__(self, weights, intercept):
+ super(LinearBinaryClassificationModel, self).__init__(weights, intercept)
+ self._threshold = None
+
+ def setThreshold(self, value):
+ """
+ :: Experimental ::
+
+ Sets the threshold that separates positive predictions from negative
+ predictions. An example with prediction score greater than or equal
+ to this threshold is identified as positive, and negative otherwise.
+ """
+ self._threshold = value
+
+ def clearThreshold(self):
+ """
+ :: Experimental ::
+
+ Clears the threshold so that `predict` will output raw prediction scores.
+ """
+ self._threshold = None
+
+ def predict(self, test):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ raise NotImplementedError
+
+
+class LogisticRegressionModel(LinearBinaryClassificationModel):
"""A linear binary classification model derived from logistic regression.
>>> data = [
- ... LabeledPoint(0.0, [0.0]),
- ... LabeledPoint(1.0, [1.0]),
- ... LabeledPoint(1.0, [2.0]),
- ... LabeledPoint(1.0, [3.0])
+ ... LabeledPoint(0.0, [0.0, 1.0]),
+ ... LabeledPoint(1.0, [1.0, 0.0]),
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(data))
- >>> lrm.predict(array([1.0])) > 0
- True
- >>> lrm.predict(array([0.0])) <= 0
- True
+ >>> lrm.predict([1.0, 0.0])
+ 1
+ >>> lrm.predict([0.0, 1.0])
+ 0
+ >>> lrm.predict(sc.parallelize([[1.0, 0.0], [0.0, 1.0]])).collect()
+ [1, 0]
+ >>> lrm.clearThreshold()
+ >>> lrm.predict([0.0, 1.0])
+ 0.123...
+
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
- ... LabeledPoint(0.0, SparseVector(2, {0: 0.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> lrm = LogisticRegressionWithSGD.train(sc.parallelize(sparse_data))
- >>> lrm.predict(array([0.0, 1.0])) > 0
- True
- >>> lrm.predict(array([0.0, 0.0])) <= 0
- True
- >>> lrm.predict(SparseVector(2, {1: 1.0})) > 0
- True
- >>> lrm.predict(SparseVector(2, {1: 0.0})) <= 0
- True
+ >>> lrm.predict(array([0.0, 1.0]))
+ 1
+ >>> lrm.predict(array([1.0, 0.0]))
+ 0
+ >>> lrm.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> lrm.predict(SparseVector(2, {0: 1.0}))
+ 0
"""
+ def __init__(self, weights, intercept):
+ super(LogisticRegressionModel, self).__init__(weights, intercept)
+ self._threshold = 0.5
def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
+
+ x = _convert_to_vector(x)
margin = self.weights.dot(x) + self._intercept
if margin > 0:
prob = 1 / (1 + exp(-margin))
else:
exp_margin = exp(margin)
prob = exp_margin / (1 + exp_margin)
- return 1 if prob > 0.5 else 0
+ if self._threshold is None:
+ return prob
+ else:
+ return 1 if prob > self._threshold else 0
class LogisticRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=1.0, regType="none", intercept=False):
+ initialWeights=None, regParam=0.01, regType="l2", intercept=False):
"""
Train a logistic regression model on the given data.
- :param data: The training data.
+ :param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.01).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- - "l1" for using L1Updater
- - "l2" for using SquaredL2Updater
- - "none" for no regularizer
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
- (default: "none")
+ (default: "l2")
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
def train(rdd, i):
- return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, iterations, step,
- miniBatchFraction, i, regParam, regType, intercept)
+ return callMLlibFunc("trainLogisticRegressionModelWithSGD", rdd, int(iterations),
+ float(step), float(miniBatchFraction), i, float(regParam), regType,
+ bool(intercept))
+
+ return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
+
+
+class LogisticRegressionWithLBFGS(object):
+
+ @classmethod
+ def train(cls, data, iterations=100, initialWeights=None, regParam=0.01, regType="l2",
+ intercept=False, corrections=10, tolerance=1e-4):
+ """
+ Train a logistic regression model on the given data.
+
+ :param data: The training data, an RDD of LabeledPoint.
+ :param iterations: The number of iterations (default: 100).
+ :param initialWeights: The initial weights (default: None).
+ :param regParam: The regularizer parameter (default: 0.01).
+ :param regType: The type of regularizer used for training
+ our model.
+
+ :Allowed values:
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
+
+ (default: "l2")
+
+ :param intercept: Boolean parameter which indicates the use
+ or not of the augmented representation for
+ training data (i.e. whether bias features
+ are activated or not).
+ :param corrections: The number of corrections used in the LBFGS
+ update (default: 10).
+ :param tolerance: The convergence tolerance of iterations for
+ L-BFGS (default: 1e-4).
+
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0, 1.0]),
+ ... LabeledPoint(1.0, [1.0, 0.0]),
+ ... ]
+ >>> lrm = LogisticRegressionWithLBFGS.train(sc.parallelize(data))
+ >>> lrm.predict([1.0, 0.0])
+ 1
+ >>> lrm.predict([0.0, 1.0])
+ 0
+ """
+ def train(rdd, i):
+ return callMLlibFunc("trainLogisticRegressionModelWithLBFGS", rdd, int(iterations), i,
+ float(regParam), str(regType), bool(intercept), int(corrections),
+ float(tolerance))
return _regression_train_wrapper(train, LogisticRegressionModel, data, initialWeights)
-class SVMModel(LinearModel):
+class SVMModel(LinearBinaryClassificationModel):
"""A support vector machine.
@@ -120,8 +224,14 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, [3.0])
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(data))
- >>> svm.predict(array([1.0])) > 0
- True
+ >>> svm.predict([1.0])
+ 1
+ >>> svm.predict(sc.parallelize([[1.0]])).collect()
+ [1]
+ >>> svm.clearThreshold()
+ >>> svm.predict(array([1.0]))
+ 1.25...
+
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {0: -1.0})),
... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
@@ -129,30 +239,44 @@ class SVMModel(LinearModel):
... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
... ]
>>> svm = SVMWithSGD.train(sc.parallelize(sparse_data))
- >>> svm.predict(SparseVector(2, {1: 1.0})) > 0
- True
- >>> svm.predict(SparseVector(2, {0: -1.0})) <= 0
- True
+ >>> svm.predict(SparseVector(2, {1: 1.0}))
+ 1
+ >>> svm.predict(SparseVector(2, {0: -1.0}))
+ 0
"""
+ def __init__(self, weights, intercept):
+ super(SVMModel, self).__init__(weights, intercept)
+ self._threshold = 0.0
def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
+
+ x = _convert_to_vector(x)
margin = self.weights.dot(x) + self.intercept
- return 1 if margin >= 0 else 0
+ if self._threshold is None:
+ return margin
+ else:
+ return 1 if margin > self._threshold else 0
class SVMWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
- miniBatchFraction=1.0, initialWeights=None, regType="none", intercept=False):
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
+ miniBatchFraction=1.0, initialWeights=None, regType="l2", intercept=False):
"""
Train a support vector machine on the given data.
- :param data: The training data.
+ :param data: The training data, an RDD of LabeledPoint.
:param iterations: The number of iterations (default: 100).
:param step: The step parameter used in SGD
(default: 1.0).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.01).
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
@@ -160,20 +284,21 @@ def train(cls, data, iterations=100, step=1.0, regParam=1.0,
our model.
:Allowed values:
- - "l1" for using L1Updater
- - "l2" for using SquaredL2Updater,
- - "none" for no regularizer.
+ - "l1" for using L1 regularization
+ - "l2" for using L2 regularization
+ - None for no regularization
- (default: "none")
+ (default: "l2")
- @param intercept: Boolean parameter which indicates the use
+ :param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
training data (i.e. whether bias features
are activated or not).
"""
def train(rdd, i):
- return callMLlibFunc("trainSVMModelWithSGD", rdd, iterations, step, regParam,
- miniBatchFraction, i, regType, intercept)
+ return callMLlibFunc("trainSVMModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i, regType,
+ bool(intercept))
return _regression_train_wrapper(train, SVMModel, data, initialWeights)
@@ -197,6 +322,8 @@ class NaiveBayesModel(object):
0.0
>>> model.predict(array([1.0, 0.0]))
1.0
+ >>> model.predict(sc.parallelize([[1.0, 0.0]])).collect()
+ [1.0]
>>> sparse_data = [
... LabeledPoint(0.0, SparseVector(2, {1: 0.0})),
... LabeledPoint(0.0, SparseVector(2, {1: 1.0})),
@@ -215,7 +342,9 @@ def __init__(self, labels, pi, theta):
self.theta = theta
def predict(self, x):
- """Return the most likely class for a data vector x"""
+ """Return the most likely class for a data vector or an RDD of vectors"""
+ if isinstance(x, RDD):
+ return x.map(lambda v: self.predict(v))
x = _convert_to_vector(x)
return self.labels[numpy.argmax(self.pi + x.dot(self.theta.transpose()))]
@@ -233,11 +362,12 @@ def train(cls, data, lambda_=1.0):
classification. By making every vector a 0-1 vector, it can also be
used as Bernoulli NB (U{http://tinyurl.com/p7c96j6}).
- :param data: RDD of NumPy vectors, one per element, where the first
- coordinate is the label and the rest is the feature vector
- (e.g. a count vector).
+ :param data: RDD of LabeledPoint.
:param lambda_: The smoothing parameter
"""
+ first = data.first()
+ if not isinstance(first, LabeledPoint):
+ raise ValueError("`data` should be an RDD of LabeledPoint")
labels, pi, theta = callMLlibFunc("trainNaiveBayes", data, lambda_)
return NaiveBayesModel(labels.toArray(), pi.toArray(), numpy.array(theta))
@@ -245,7 +375,8 @@ def train(cls, data, lambda_=1.0):
def _test():
import doctest
from pyspark import SparkContext
- globs = globals().copy()
+ import pyspark.mllib.classification
+ globs = pyspark.mllib.classification.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
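
A short sketch (not part of the patch) exercising the classification API added above: the LBFGS trainer, hard-vs-raw predictions via the threshold, and RDD input to predict(). It assumes a local Spark build with this change; the toy data mirrors the doctests.

from pyspark import SparkContext
from pyspark.mllib.classification import LogisticRegressionWithLBFGS
from pyspark.mllib.regression import LabeledPoint

sc = SparkContext("local[2]", "ClassificationSketch")
data = sc.parallelize([LabeledPoint(0.0, [0.0, 1.0]), LabeledPoint(1.0, [1.0, 0.0])])
model = LogisticRegressionWithLBFGS.train(data)
print(model.predict([1.0, 0.0]))                              # hard 0/1 label (threshold 0.5)
print(model.predict(sc.parallelize([[0.0, 1.0]])).collect())  # RDD input is now supported
model.clearThreshold()
print(model.predict([1.0, 0.0]))                              # raw probability once cleared
sc.stop()
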
diff --git a/python/pyspark/mllib/clustering.py b/python/pyspark/mllib/clustering.py
index fe4c4cc5094d8..e2492eef5bd6a 100644
--- a/python/pyspark/mllib/clustering.py
+++ b/python/pyspark/mllib/clustering.py
@@ -16,7 +16,7 @@
#
from pyspark import SparkContext
-from pyspark.mllib.common import callMLlibFunc, callJavaFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc, callJavaFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['KMeansModel', 'KMeans']
@@ -80,10 +80,8 @@ class KMeans(object):
@classmethod
def train(cls, rdd, k, maxIterations=100, runs=1, initializationMode="k-means||"):
"""Train a k-means clustering model."""
- # cache serialized data to avoid objects over head in JVM
- jcached = _to_java_object_rdd(rdd.map(_convert_to_vector), cache=True)
- model = callMLlibFunc("trainKMeansModel", jcached, k, maxIterations, runs,
- initializationMode)
+ model = callMLlibFunc("trainKMeansModel", rdd.map(_convert_to_vector), k, maxIterations,
+ runs, initializationMode)
centers = callJavaFunc(rdd.context, model.clusterCenters)
return KMeansModel([c.toArray() for c in centers])
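
The clustering change only drops the explicit JVM-side caching; training is otherwise unchanged. For reference, a minimal KMeans usage sketch (not part of the patch; data and parameters are illustrative):

from numpy import array
from pyspark import SparkContext
from pyspark.mllib.clustering import KMeans

sc = SparkContext("local[2]", "KMeansSketch")
points = sc.parallelize([[0.0, 0.0], [1.0, 1.0], [9.0, 8.0], [8.0, 9.0]])
model = KMeans.train(points, 2, maxIterations=10, runs=1, initializationMode="k-means||")
print(model.predict(array([0.5, 0.5])))   # index of the closest cluster center
sc.stop()
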
diff --git a/python/pyspark/mllib/common.py b/python/pyspark/mllib/common.py
index 76864d8163586..3c5ee66cd8b64 100644
--- a/python/pyspark/mllib/common.py
+++ b/python/pyspark/mllib/common.py
@@ -18,7 +18,7 @@
import py4j.protocol
from py4j.protocol import Py4JJavaError
from py4j.java_gateway import JavaObject
-from py4j.java_collections import MapConverter, ListConverter, JavaArray, JavaList
+from py4j.java_collections import ListConverter, JavaArray, JavaList
from pyspark import RDD, SparkContext
from pyspark.serializers import PickleSerializer, AutoBatchedSerializer
@@ -54,15 +54,13 @@ def _new_smart_decode(obj):
# this will call the MLlib version of pythonToJava()
-def _to_java_object_rdd(rdd, cache=False):
+def _to_java_object_rdd(rdd):
""" Return an JavaRDD of Object by unpickling
It will convert each Python object into Java object by Pyrolite, whenever the
RDD is serialized in batch or not.
"""
rdd = rdd._reserialize(AutoBatchedSerializer(PickleSerializer()))
- if cache:
- rdd.cache()
return rdd.ctx._jvm.SerDe.pythonToJava(rdd._jrdd, True)
@@ -72,9 +70,7 @@ def _py2java(sc, obj):
obj = _to_java_object_rdd(obj)
elif isinstance(obj, SparkContext):
obj = obj._jsc
- elif isinstance(obj, dict):
- obj = MapConverter().convert(obj, sc._gateway._gateway_client)
- elif isinstance(obj, (list, tuple)):
+ elif isinstance(obj, list) and (obj or isinstance(obj[0], JavaObject)):
obj = ListConverter().convert(obj, sc._gateway._gateway_client)
elif isinstance(obj, JavaObject):
pass
@@ -96,10 +92,15 @@ def _java2py(sc, r):
if clsName == 'JavaRDD':
jrdd = sc._jvm.SerDe.javaToPython(r)
- return RDD(jrdd, sc, AutoBatchedSerializer(PickleSerializer()))
+ return RDD(jrdd, sc)
- elif isinstance(r, (JavaArray, JavaList)) or clsName in _picklable_classes:
+ if clsName in _picklable_classes:
r = sc._jvm.SerDe.dumps(r)
+ elif isinstance(r, (JavaArray, JavaList)):
+ try:
+ r = sc._jvm.SerDe.dumps(r)
+ except Py4JJavaError:
+ pass # not pickable
if isinstance(r, bytearray):
r = PickleSerializer().loads(str(r))
diff --git a/python/pyspark/mllib/feature.py b/python/pyspark/mllib/feature.py
index 44bf6f269d7a3..7f532139272d3 100644
--- a/python/pyspark/mllib/feature.py
+++ b/python/pyspark/mllib/feature.py
@@ -18,14 +18,17 @@
"""
Python package for feature in MLlib.
"""
+from __future__ import absolute_import
+
import sys
import warnings
+import random
from py4j.protocol import Py4JJavaError
from pyspark import RDD, SparkContext
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import Vectors
+from pyspark.mllib.linalg import Vectors, _convert_to_vector
__all__ = ['Normalizer', 'StandardScalerModel', 'StandardScaler',
'HashingTF', 'IDFModel', 'IDF', 'Word2Vec', 'Word2VecModel']
@@ -50,10 +53,10 @@ class Normalizer(VectorTransformer):
"""
:: Experimental ::
- Normalizes samples individually to unit L\ :sup:`p`\ norm
+ Normalizes samples individually to unit L\ :sup:`p`\ norm
- For any 1 <= `p` <= float('inf'), normalizes samples using
- sum(abs(vector). :sup:`p`) :sup:`(1/p)` as norm.
+ For any 1 <= `p` < float('inf'), normalizes samples using
+ sum(abs(vector) :sup:`p`) :sup:`(1/p)` as norm.
For `p` = float('inf'), max(abs(vector)) will be used as norm for normalization.
@@ -81,12 +84,16 @@ def transform(self, vector):
"""
Applies unit length normalization on a vector.
- :param vector: vector to be normalized.
+ :param vector: vector or RDD of vector to be normalized.
:return: normalized vector. If the norm of the input is zero, it
will return the input vector.
"""
sc = SparkContext._active_spark_context
assert sc is not None, "SparkContext should be initialized first"
+ if isinstance(vector, RDD):
+ vector = vector.map(_convert_to_vector)
+ else:
+ vector = _convert_to_vector(vector)
return callMLlibFunc("normalizeVector", self.p, vector)
@@ -95,8 +102,12 @@ class JavaVectorTransformer(JavaModelWrapper, VectorTransformer):
Wrapper for the model in JVM
"""
- def transform(self, dataset):
- return self.call("transform", dataset)
+ def transform(self, vector):
+ if isinstance(vector, RDD):
+ vector = vector.map(_convert_to_vector)
+ else:
+ vector = _convert_to_vector(vector)
+ return self.call("transform", vector)
class StandardScalerModel(JavaVectorTransformer):
@@ -109,7 +120,7 @@ def transform(self, vector):
"""
Applies standardization transformation on a vector.
- :param vector: Vector to be standardized.
+ :param vector: Vector or RDD of Vector to be standardized.
:return: Standardized vector. If the variance of a column is zero,
it will return default `0.0` for the column with zero variance.
"""
@@ -154,6 +165,7 @@ def fit(self, dataset):
the transformation model.
:return: a StandardScalarModel
"""
+ dataset = dataset.map(_convert_to_vector)
jmodel = callMLlibFunc("fitStandardScaler", self.withMean, self.withStd, dataset)
return StandardScalerModel(jmodel)
@@ -211,6 +223,8 @@ def transform(self, dataset):
:param dataset: an RDD of term frequency vectors
:return: an RDD of TF-IDF vectors
"""
+ if not isinstance(dataset, RDD):
+ raise TypeError("dataset should be an RDD of term frequency vectors")
return JavaVectorTransformer.transform(self, dataset)
@@ -255,7 +269,9 @@ def fit(self, dataset):
:param dataset: an RDD of term frequency vectors
"""
- jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset)
+ if not isinstance(dataset, RDD):
+ raise TypeError("dataset should be an RDD of term frequency vectors")
+ jmodel = callMLlibFunc("fitIDF", self.minDocFreq, dataset.map(_convert_to_vector))
return IDFModel(jmodel)
@@ -287,6 +303,8 @@ def findSynonyms(self, word, num):
Note: local use only
"""
+ if not isinstance(word, basestring):
+ word = _convert_to_vector(word)
words, similarity = self.call("findSynonyms", word, num)
return zip(words, similarity)
@@ -326,8 +344,6 @@ def __init__(self):
"""
Construct Word2Vec instance
"""
- import random # this can't be on the top because of mllib.random
-
self.vectorSize = 100
self.learningRate = 0.025
self.numPartitions = 1
@@ -374,9 +390,11 @@ def fit(self, data):
"""
Computes the vector representation of each word in vocabulary.
- :param data: training data. RDD of subtype of Iterable[String]
+ :param data: training data. RDD of list of string
:return: Word2VecModel instance
"""
+ if not isinstance(data, RDD):
+ raise TypeError("data should be an RDD of list of string")
jmodel = callMLlibFunc("trainWord2Vec", data, int(self.vectorSize),
float(self.learningRate), int(self.numPartitions),
int(self.numIterations), long(self.seed))
@@ -394,8 +412,5 @@ def _test():
exit(-1)
if __name__ == "__main__":
- # remove current path from list of search paths to avoid importing mllib.random
- # for C{import random}, which is done in an external dependency of pyspark during doctests.
- import sys
sys.path.pop(0)
_test()
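
A small usage sketch (not part of the patch) of the transformer changes above: Normalizer and StandardScaler now accept both a single vector and an RDD of vectors, converting inputs with _convert_to_vector. Master, app name, and data are illustrative.

from pyspark import SparkContext
from pyspark.mllib.feature import Normalizer, StandardScaler
from pyspark.mllib.linalg import Vectors

sc = SparkContext("local[2]", "FeatureSketch")
vecs = sc.parallelize([Vectors.dense([3.0, 4.0]), Vectors.dense([0.0, 1.0])])
nor = Normalizer(p=2.0)
print(nor.transform(Vectors.dense([3.0, 4.0])))   # single vector -> unit L2 norm
print(nor.transform(vecs).collect())              # an RDD of vectors is accepted as well
model = StandardScaler(withMean=False, withStd=True).fit(vecs)
print(model.transform(vecs).collect())            # columns scaled to unit standard deviation
sc.stop()
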
diff --git a/python/pyspark/mllib/linalg.py b/python/pyspark/mllib/linalg.py
index d0a0e102a1a07..4f8491f43e457 100644
--- a/python/pyspark/mllib/linalg.py
+++ b/python/pyspark/mllib/linalg.py
@@ -29,8 +29,11 @@
import numpy as np
+from pyspark.sql import UserDefinedType, StructField, StructType, ArrayType, DoubleType, \
+ IntegerType, ByteType
-__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors']
+
+__all__ = ['Vector', 'DenseVector', 'SparseVector', 'Vectors', 'DenseMatrix', 'Matrices']
if sys.version_info[:2] == (2, 7):
@@ -106,7 +109,54 @@ def _format_float(f, digits=4):
return s
+class VectorUDT(UserDefinedType):
+ """
+ SQL user-defined type (UDT) for Vector.
+ """
+
+ @classmethod
+ def sqlType(cls):
+ return StructType([
+ StructField("type", ByteType(), False),
+ StructField("size", IntegerType(), True),
+ StructField("indices", ArrayType(IntegerType(), False), True),
+ StructField("values", ArrayType(DoubleType(), False), True)])
+
+ @classmethod
+ def module(cls):
+ return "pyspark.mllib.linalg"
+
+ @classmethod
+ def scalaUDT(cls):
+ return "org.apache.spark.mllib.linalg.VectorUDT"
+
+ def serialize(self, obj):
+ if isinstance(obj, SparseVector):
+ indices = [int(i) for i in obj.indices]
+ values = [float(v) for v in obj.values]
+ return (0, obj.size, indices, values)
+ elif isinstance(obj, DenseVector):
+ values = [float(v) for v in obj]
+ return (1, None, None, values)
+ else:
+ raise ValueError("cannot serialize %r of type %r" % (obj, type(obj)))
+
+ def deserialize(self, datum):
+ assert len(datum) == 4, \
+ "VectorUDT.deserialize given row with length %d but requires 4" % len(datum)
+ tpe = datum[0]
+ if tpe == 0:
+ return SparseVector(datum[1], datum[2], datum[3])
+ elif tpe == 1:
+ return DenseVector(datum[3])
+ else:
+ raise ValueError("do not recognize type %r" % tpe)
+
+
class Vector(object):
+
+ __UDT__ = VectorUDT()
+
"""
Abstract class for DenseVector and SparseVector
"""
@@ -123,12 +173,16 @@ class DenseVector(Vector):
A dense vector represented by a value array.
"""
def __init__(self, ar):
- if not isinstance(ar, array.array):
- ar = array.array('d', ar)
+ if isinstance(ar, basestring):
+ ar = np.frombuffer(ar, dtype=np.float64)
+ elif not isinstance(ar, np.ndarray):
+ ar = np.array(ar, dtype=np.float64)
+ if ar.dtype != np.float64:
+ ar = ar.astype(np.float64)
self.array = ar
def __reduce__(self):
- return DenseVector, (self.array,)
+ return DenseVector, (self.array.tostring(),)
def dot(self, other):
"""
@@ -157,9 +211,10 @@ def dot(self, other):
...
AssertionError: dimension mismatch
"""
- if type(other) == np.ndarray and other.ndim > 1:
- assert len(self) == other.shape[0], "dimension mismatch"
- return np.dot(self.toArray(), other)
+ if type(other) == np.ndarray:
+ if other.ndim > 1:
+ assert len(self) == other.shape[0], "dimension mismatch"
+ return np.dot(self.array, other)
elif _have_scipy and scipy.sparse.issparse(other):
assert len(self) == other.shape[0], "dimension mismatch"
return other.transpose().dot(self.toArray())
@@ -211,7 +266,7 @@ def squared_distance(self, other):
return np.dot(diff, diff)
def toArray(self):
- return np.array(self.array)
+ return self.array
def __getitem__(self, item):
return self.array[item]
@@ -226,7 +281,7 @@ def __repr__(self):
return "DenseVector([%s])" % (', '.join(_format_float(i) for i in self.array))
def __eq__(self, other):
- return isinstance(other, DenseVector) and self.array == other.array
+ return isinstance(other, DenseVector) and np.array_equal(self.array, other.array)
def __ne__(self, other):
return not self == other
@@ -264,18 +319,28 @@ def __init__(self, size, *args):
if type(pairs) == dict:
pairs = pairs.items()
pairs = sorted(pairs)
- self.indices = array.array('i', [p[0] for p in pairs])
- self.values = array.array('d', [p[1] for p in pairs])
+ self.indices = np.array([p[0] for p in pairs], dtype=np.int32)
+ self.values = np.array([p[1] for p in pairs], dtype=np.float64)
else:
- assert len(args[0]) == len(args[1]), "index and value arrays not same length"
- self.indices = array.array('i', args[0])
- self.values = array.array('d', args[1])
+ if isinstance(args[0], basestring):
+ assert isinstance(args[1], str), "values should be string too"
+ if args[0]:
+ self.indices = np.frombuffer(args[0], np.int32)
+ self.values = np.frombuffer(args[1], np.float64)
+ else:
+ # np.frombuffer() doesn't work well with an empty string in older NumPy versions
+ self.indices = np.array([], dtype=np.int32)
+ self.values = np.array([], dtype=np.float64)
+ else:
+ self.indices = np.array(args[0], dtype=np.int32)
+ self.values = np.array(args[1], dtype=np.float64)
+ assert len(self.indices) == len(self.values), "index and value arrays not same length"
for i in xrange(len(self.indices) - 1):
if self.indices[i] >= self.indices[i + 1]:
raise TypeError("indices array must be sorted")
def __reduce__(self):
- return (SparseVector, (self.size, self.indices, self.values))
+ return (SparseVector, (self.size, self.indices.tostring(), self.values.tostring()))
def dot(self, other):
"""
@@ -411,8 +476,7 @@ def toArray(self):
Returns a copy of this SparseVector as a 1-dimensional NumPy array.
"""
arr = np.zeros((self.size,), dtype=np.float64)
- for i in xrange(len(self.indices)):
- arr[self.indices[i]] = self.values[i]
+ arr[self.indices] = self.values
return arr
def __len__(self):
@@ -443,8 +507,8 @@ def __eq__(self, other):
"""
return (isinstance(other, self.__class__)
and other.size == self.size
- and other.indices == self.indices
- and other.values == self.values)
+ and np.array_equal(other.indices, self.indices)
+ and np.array_equal(other.values, self.values))
def __ne__(self, other):
return not self.__eq__(other)
@@ -527,23 +591,43 @@ class DenseMatrix(Matrix):
"""
def __init__(self, numRows, numCols, values):
Matrix.__init__(self, numRows, numCols)
+ if isinstance(values, basestring):
+ values = np.frombuffer(values, dtype=np.float64)
+ elif not isinstance(values, np.ndarray):
+ values = np.array(values, dtype=np.float64)
assert len(values) == numRows * numCols
+ if values.dtype != np.float64:
+ values = values.astype(np.float64)
self.values = values
def __reduce__(self):
- return DenseMatrix, (self.numRows, self.numCols, self.values)
+ return DenseMatrix, (self.numRows, self.numCols, self.values.tostring())
def toArray(self):
"""
Return an numpy.ndarray
- >>> arr = array.array('d', [float(i) for i in range(4)])
- >>> m = DenseMatrix(2, 2, arr)
+ >>> m = DenseMatrix(2, 2, range(4))
>>> m.toArray()
array([[ 0., 2.],
[ 1., 3.]])
"""
- return np.reshape(self.values, (self.numRows, self.numCols), order='F')
+ return self.values.reshape((self.numRows, self.numCols), order='F')
+
+ def __eq__(self, other):
+ return (isinstance(other, DenseMatrix) and
+ self.numRows == other.numRows and
+ self.numCols == other.numCols and
+ all(self.values == other.values))
+
+
+class Matrices(object):
+ @staticmethod
+ def dense(numRows, numCols, values):
+ """
+ Create a DenseMatrix
+ """
+ return DenseMatrix(numRows, numCols, values)
def _test():
@@ -553,8 +637,4 @@ def _test():
exit(-1)
if __name__ == "__main__":
- # remove current path from list of search paths to avoid importing mllib.random
- # for C{import random}, which is done in an external dependency of pyspark during doctests.
- import sys
- sys.path.pop(0)
_test()
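
The linalg changes switch vector storage to NumPy arrays and add DenseMatrix/Matrices. A local sketch (no SparkContext needed; values illustrative) grounded in the doctests above:

import numpy as np
from pyspark.mllib.linalg import Vectors, Matrices, SparseVector

dv = Vectors.dense([1.0, 2.0])
sv = SparseVector(4, {1: 1.0, 3: 5.5})
print(dv.dot(np.array([1.0, 1.0])))              # 3.0; 1-d ndarrays are dotted directly now
print(sv.toArray())                              # [ 0.   1.   0.   5.5]
m = Matrices.dense(2, 2, range(4))               # column-major, as in the toArray() doctest
print(m.toArray())
print(m == Matrices.dense(2, 2, [0, 1, 2, 3]))   # True: the new __eq__ compares shape and values
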
diff --git a/python/pyspark/mllib/random.py b/python/pyspark/mllib/rand.py
similarity index 69%
rename from python/pyspark/mllib/random.py
rename to python/pyspark/mllib/rand.py
index 7eebfc6bcd894..cb4304f92152b 100644
--- a/python/pyspark/mllib/random.py
+++ b/python/pyspark/mllib/rand.py
@@ -52,6 +52,12 @@ def uniformRDD(sc, size, numPartitions=None, seed=None):
C{RandomRDDs.uniformRDD(sc, n, p, seed)\
.map(lambda v: a + (b - a) * v)}
+ :param sc: SparkContext used to create the RDD.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ `U(0.0, 1.0)`.
+
>>> x = RandomRDDs.uniformRDD(sc, 100).collect()
>>> len(x)
100
@@ -76,6 +82,12 @@ def normalRDD(sc, size, numPartitions=None, seed=None):
C{RandomRDDs.normal(sc, n, p, seed)\
.map(lambda v: mean + sigma * v)}
+ :param sc: SparkContext used to create the RDD.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ N(0.0, 1.0).
+
>>> x = RandomRDDs.normalRDD(sc, 1000, seed=1L)
>>> stats = x.stats()
>>> stats.count()
@@ -93,6 +105,13 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
Generates an RDD comprised of i.i.d. samples from the Poisson
distribution with the input mean.
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean, or lambda, for the Poisson distribution.
+ :param size: Size of the RDD.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of float comprised of i.i.d. samples ~ Pois(mean).
+
>>> mean = 100.0
>>> x = RandomRDDs.poissonRDD(sc, mean, 1000, seed=2L)
>>> stats = x.stats()
@@ -104,7 +123,7 @@ def poissonRDD(sc, mean, size, numPartitions=None, seed=None):
>>> abs(stats.stdev() - sqrt(mean)) < 0.5
True
"""
- return callMLlibFunc("poissonRDD", sc._jsc, mean, size, numPartitions, seed)
+ return callMLlibFunc("poissonRDD", sc._jsc, float(mean), size, numPartitions, seed)
@staticmethod
@toArray
@@ -113,6 +132,13 @@ def uniformVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
Generates an RDD comprised of vectors containing i.i.d. samples drawn
from the uniform distribution U(0.0, 1.0).
+ :param sc: SparkContext used to create the RDD.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD.
+ :param seed: Seed for the RNG that generates the seed for the generator in each partition.
+ :return: RDD of Vector with vectors containing i.i.d samples ~ `U(0.0, 1.0)`.
+
>>> import numpy as np
>>> mat = np.matrix(RandomRDDs.uniformVectorRDD(sc, 10, 10).collect())
>>> mat.shape
@@ -131,6 +157,13 @@ def normalVectorRDD(sc, numRows, numCols, numPartitions=None, seed=None):
Generates an RDD comprised of vectors containing i.i.d. samples drawn
from the standard normal distribution.
+ :param sc: SparkContext used to create the RDD.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`).
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ `N(0.0, 1.0)`.
+
>>> import numpy as np
>>> mat = np.matrix(RandomRDDs.normalVectorRDD(sc, 100, 100, seed=1L).collect())
>>> mat.shape
@@ -149,6 +182,14 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
Generates an RDD comprised of vectors containing i.i.d. samples drawn
from the Poisson distribution with the input mean.
+ :param sc: SparkContext used to create the RDD.
+ :param mean: Mean, or lambda, for the Poisson distribution.
+ :param numRows: Number of Vectors in the RDD.
+ :param numCols: Number of elements in each Vector.
+ :param numPartitions: Number of partitions in the RDD (default: `sc.defaultParallelism`)
+ :param seed: Random seed (default: a random long integer).
+ :return: RDD of Vector with vectors containing i.i.d. samples ~ Pois(mean).
+
>>> import numpy as np
>>> mean = 100.0
>>> rdd = RandomRDDs.poissonVectorRDD(sc, mean, 100, 100, seed=1L)
@@ -161,7 +202,7 @@ def poissonVectorRDD(sc, mean, numRows, numCols, numPartitions=None, seed=None):
>>> abs(mat.std() - sqrt(mean)) < 0.5
True
"""
- return callMLlibFunc("poissonVectorRDD", sc._jsc, mean, numRows, numCols,
+ return callMLlibFunc("poissonVectorRDD", sc._jsc, float(mean), numRows, numCols,
numPartitions, seed)
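
A minimal driver-side sketch (not part of the patch, Python 2) of the RandomRDDs helpers documented above; it relies on the aliasing in mllib/__init__.py so that pyspark.mllib.random resolves to rand.py.

from pyspark import SparkContext
from pyspark.mllib.random import RandomRDDs

sc = SparkContext("local[2]", "RandomRDDsSketch")
x = RandomRDDs.normalRDD(sc, 1000, seed=1L)
print(x.stats())                                  # count / mean / stdev of ~N(0, 1) samples
mat = RandomRDDs.poissonVectorRDD(sc, 100.0, numRows=10, numCols=5, seed=1L)
print(mat.first())                                # one vector of i.i.d. Pois(100.0) samples
sc.stop()
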
diff --git a/python/pyspark/mllib/recommendation.py b/python/pyspark/mllib/recommendation.py
index 6b32af07c9be2..97ec74eda0b71 100644
--- a/python/pyspark/mllib/recommendation.py
+++ b/python/pyspark/mllib/recommendation.py
@@ -15,24 +15,28 @@
# limitations under the License.
#
+from collections import namedtuple
+
from pyspark import SparkContext
from pyspark.rdd import RDD
-from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import JavaModelWrapper, callMLlibFunc
-__all__ = ['MatrixFactorizationModel', 'ALS']
+__all__ = ['MatrixFactorizationModel', 'ALS', 'Rating']
-class Rating(object):
- def __init__(self, user, product, rating):
- self.user = int(user)
- self.product = int(product)
- self.rating = float(rating)
+class Rating(namedtuple("Rating", ["user", "product", "rating"])):
+ """
+ Represents a (user, product, rating) tuple.
- def __reduce__(self):
- return Rating, (self.user, self.product, self.rating)
+ >>> r = Rating(1, 2, 5.0)
+ >>> (r.user, r.product, r.rating)
+ (1, 2, 5.0)
+ >>> (r[0], r[1], r[2])
+ (1, 2, 5.0)
+ """
- def __repr__(self):
- return "Rating(%d, %d, %d)" % (self.user, self.product, self.rating)
+ def __reduce__(self):
+ return Rating, (int(self.user), int(self.product), float(self.rating))
class MatrixFactorizationModel(JavaModelWrapper):
@@ -44,34 +48,42 @@ class MatrixFactorizationModel(JavaModelWrapper):
>>> r2 = (1, 2, 2.0)
>>> r3 = (2, 1, 2.0)
>>> ratings = sc.parallelize([r1, r2, r3])
- >>> model = ALS.trainImplicit(ratings, 1)
- >>> model.predict(2,2) is not None
- True
+ >>> model = ALS.trainImplicit(ratings, 1, seed=10)
+ >>> model.predict(2,2)
+ 0.4473...
>>> testset = sc.parallelize([(1, 2), (1, 1)])
- >>> model = ALS.train(ratings, 1)
- >>> model.predictAll(testset).count() == 2
- True
+ >>> model = ALS.train(ratings, 1, seed=10)
+ >>> model.predictAll(testset).collect()
+ [Rating(user=1, product=1, rating=1.0471...), Rating(user=1, product=2, rating=1.9679...)]
- >>> model = ALS.train(ratings, 4)
- >>> model.userFeatures().count() == 2
- True
+ >>> model = ALS.train(ratings, 4, seed=10)
+ >>> model.userFeatures().collect()
+ [(2, array('d', [...])), (1, array('d', [...]))]
>>> first_user = model.userFeatures().take(1)[0]
>>> latents = first_user[1]
>>> len(latents) == 4
True
- >>> model.productFeatures().count() == 2
- True
+ >>> model.productFeatures().collect()
+ [(2, array('d', [...])), (1, array('d', [...]))]
>>> first_product = model.productFeatures().take(1)[0]
>>> latents = first_product[1]
>>> len(latents) == 4
True
+
+ >>> model = ALS.train(ratings, 1, nonnegative=True, seed=10)
+ >>> model.predict(2,2)
+ 3.735...
+
+ >>> model = ALS.trainImplicit(ratings, 1, nonnegative=True, seed=10)
+ >>> model.predict(2,2)
+ 0.4473...
"""
def predict(self, user, product):
- return self._java_model.predict(user, product)
+ return self._java_model.predict(int(user), int(product))
def predictAll(self, user_product):
assert isinstance(user_product, RDD), "user_product should be RDD of (user, product)"
@@ -98,18 +110,20 @@ def _prepare(cls, ratings):
ratings = ratings.map(lambda x: Rating(*x))
else:
raise ValueError("rating should be RDD of Rating or tuple/list")
- return _to_java_object_rdd(ratings, True)
+ return ratings
@classmethod
- def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1):
+ def train(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, nonnegative=False,
+ seed=None):
model = callMLlibFunc("trainALSModel", cls._prepare(ratings), rank, iterations,
- lambda_, blocks)
+ lambda_, blocks, nonnegative, seed)
return MatrixFactorizationModel(model)
@classmethod
- def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01):
+ def trainImplicit(cls, ratings, rank, iterations=5, lambda_=0.01, blocks=-1, alpha=0.01,
+ nonnegative=False, seed=None):
model = callMLlibFunc("trainImplicitALSModel", cls._prepare(ratings), rank,
- iterations, lambda_, blocks, alpha)
+ iterations, lambda_, blocks, alpha, nonnegative, seed)
return MatrixFactorizationModel(model)
@@ -117,7 +131,7 @@ def _test():
import doctest
import pyspark.mllib.recommendation
globs = pyspark.mllib.recommendation.__dict__.copy()
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
if failure_count:
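
A hypothetical end-to-end sketch of the updated ALS API above (Rating as a namedtuple plus the new seed and nonnegative arguments); the ratings and parameters mirror the doctests, and the master and app name are illustrative.

from pyspark import SparkContext
from pyspark.mllib.recommendation import ALS, Rating

sc = SparkContext("local[2]", "ALSSketch")
ratings = sc.parallelize([Rating(1, 1, 1.0), Rating(1, 2, 2.0), Rating(2, 1, 2.0)])
model = ALS.train(ratings, 1, seed=10, nonnegative=True)             # new keyword arguments
print(model.predict(2, 2))                                           # single (user, product) score
print(model.predictAll(sc.parallelize([(1, 2), (1, 1)])).collect())  # list of Rating tuples
sc.stop()
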
diff --git a/python/pyspark/mllib/regression.py b/python/pyspark/mllib/regression.py
index 43c1a2fc101dd..210060140fd91 100644
--- a/python/pyspark/mllib/regression.py
+++ b/python/pyspark/mllib/regression.py
@@ -18,7 +18,7 @@
import numpy as np
from numpy import array
-from pyspark.mllib.common import callMLlibFunc, _to_java_object_rdd
+from pyspark.mllib.common import callMLlibFunc
from pyspark.mllib.linalg import SparseVector, _convert_to_vector
__all__ = ['LabeledPoint', 'LinearModel', 'LinearRegressionModel', 'RidgeRegressionModel',
@@ -36,7 +36,7 @@ class LabeledPoint(object):
"""
def __init__(self, label, features):
- self.label = label
+ self.label = float(label)
self.features = _convert_to_vector(features)
def __reduce__(self):
@@ -46,7 +46,7 @@ def __str__(self):
return "(" + ",".join((str(self.label), str(self.features))) + ")"
def __repr__(self):
- return "LabeledPoint(" + ",".join((repr(self.label), repr(self.features))) + ")"
+ return "LabeledPoint(%s, %s)" % (self.label, self.features)
class LinearModel(object):
@@ -55,7 +55,7 @@ class LinearModel(object):
def __init__(self, weights, intercept):
self._coeff = _convert_to_vector(weights)
- self._intercept = intercept
+ self._intercept = float(intercept)
@property
def weights(self):
@@ -66,7 +66,7 @@ def intercept(self):
return self._intercept
def __repr__(self):
- return "(weights=%s, intercept=%s)" % (self._coeff, self._intercept)
+ return "(weights=%s, intercept=%r)" % (self._coeff, self._intercept)
class LinearRegressionModelBase(LinearModel):
@@ -85,6 +85,7 @@ def predict(self, x):
Predict the value of the dependent variable given a vector x
containing values for the independent variables.
"""
+ x = _convert_to_vector(x)
return self.weights.dot(x) + self.intercept
@@ -124,9 +125,11 @@ class LinearRegressionModel(LinearRegressionModelBase):
# return the result of a call to the appropriate JVM stub.
# _regression_train_wrapper is responsible for setup and error checking.
def _regression_train_wrapper(train_func, modelClass, data, initial_weights):
+ first = data.first()
+ if not isinstance(first, LabeledPoint):
+ raise ValueError("data should be an RDD of LabeledPoint, but got %s" % first)
initial_weights = initial_weights or [0.0] * len(data.first().features)
- weights, intercept = train_func(_to_java_object_rdd(data, cache=True),
- _convert_to_vector(initial_weights))
+ weights, intercept = train_func(data, _convert_to_vector(initial_weights))
return modelClass(weights, intercept)
@@ -134,7 +137,7 @@ class LinearRegressionWithSGD(object):
@classmethod
def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
- initialWeights=None, regParam=1.0, regType="none", intercept=False):
+ initialWeights=None, regParam=0.0, regType=None, intercept=False):
"""
Train a linear regression model on the given data.
@@ -145,16 +148,16 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
:param miniBatchFraction: Fraction of data to be used for each SGD
iteration.
:param initialWeights: The initial weights (default: None).
- :param regParam: The regularizer parameter (default: 1.0).
+ :param regParam: The regularizer parameter (default: 0.0).
:param regType: The type of regularizer used for training
our model.
:Allowed values:
- - "l1" for using L1Updater,
- - "l2" for using SquaredL2Updater,
- - "none" for no regularizer.
+ - "l1" for using L1 regularization (lasso),
+ - "l2" for using L2 regularization (ridge),
+ - None for no regularization
- (default: "none")
+ (default: None)
@param intercept: Boolean parameter which indicates the use
or not of the augmented representation for
@@ -162,11 +165,11 @@ def train(cls, data, iterations=100, step=1.0, miniBatchFraction=1.0,
are activated or not).
"""
def train(rdd, i):
- return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, iterations, step,
- miniBatchFraction, i, regParam, regType, intercept)
+ return callMLlibFunc("trainLinearRegressionModelWithSGD", rdd, int(iterations),
+ float(step), float(miniBatchFraction), i, float(regParam),
+ regType, bool(intercept))
- return _regression_train_wrapper(train, LinearRegressionModel,
- data, initialWeights)
+ return _regression_train_wrapper(train, LinearRegressionModel, data, initialWeights)
class LassoModel(LinearRegressionModelBase):
@@ -205,12 +208,13 @@ class LassoModel(LinearRegressionModelBase):
class LassoWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None):
"""Train a Lasso regression model on the given data."""
def train(rdd, i):
- return callMLlibFunc("trainLassoModelWithSGD", rdd, iterations, step, regParam,
- miniBatchFraction, i)
+ return callMLlibFunc("trainLassoModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i)
+
return _regression_train_wrapper(train, LassoModel, data, initialWeights)
@@ -250,21 +254,21 @@ class RidgeRegressionModel(LinearRegressionModelBase):
class RidgeRegressionWithSGD(object):
@classmethod
- def train(cls, data, iterations=100, step=1.0, regParam=1.0,
+ def train(cls, data, iterations=100, step=1.0, regParam=0.01,
miniBatchFraction=1.0, initialWeights=None):
"""Train a ridge regression model on the given data."""
def train(rdd, i):
- return callMLlibFunc("trainRidgeModelWithSGD", rdd, iterations, step, regParam,
- miniBatchFraction, i)
+ return callMLlibFunc("trainRidgeModelWithSGD", rdd, int(iterations), float(step),
+ float(regParam), float(miniBatchFraction), i)
- return _regression_train_wrapper(train, RidgeRegressionModel,
- data, initialWeights)
+ return _regression_train_wrapper(train, RidgeRegressionModel, data, initialWeights)
def _test():
import doctest
from pyspark import SparkContext
- globs = globals().copy()
+ import pyspark.mllib.regression
+ globs = pyspark.mllib.regression.__dict__.copy()
globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
(failure_count, test_count) = doctest.testmod(globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
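
A short sketch (not part of the patch) of the regression trainers after the changes above: regType now defaults to None for plain linear regression, and predict() converts plain Python lists itself. Data and parameters are illustrative.

from pyspark import SparkContext
from pyspark.mllib.regression import LabeledPoint, LinearRegressionWithSGD

sc = SparkContext("local[2]", "RegressionSketch")
data = sc.parallelize([LabeledPoint(0.0, [0.0]), LabeledPoint(1.0, [1.0]),
                       LabeledPoint(3.0, [2.0]), LabeledPoint(2.0, [3.0])])
lrm = LinearRegressionWithSGD.train(data, iterations=50, initialWeights=[1.0])
print(lrm.weights)
print(lrm.predict([2.5]))   # lists go through _convert_to_vector before the dot product
sc.stop()
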
diff --git a/python/pyspark/mllib/stat.py b/python/pyspark/mllib/stat.py
index 15f0652f833d7..1980f5b03f430 100644
--- a/python/pyspark/mllib/stat.py
+++ b/python/pyspark/mllib/stat.py
@@ -19,11 +19,13 @@
Python package for statistical functions in MLlib.
"""
+from pyspark import RDD
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
-from pyspark.mllib.linalg import _convert_to_vector
+from pyspark.mllib.linalg import Matrix, _convert_to_vector
+from pyspark.mllib.regression import LabeledPoint
-__all__ = ['MultivariateStatisticalSummary', 'Statistics']
+__all__ = ['MultivariateStatisticalSummary', 'ChiSqTestResult', 'Statistics']
class MultivariateStatisticalSummary(JavaModelWrapper):
@@ -51,6 +53,54 @@ def min(self):
return self.call("min").toArray()
+class ChiSqTestResult(JavaModelWrapper):
+ """
+ :: Experimental ::
+
+ Object containing the test results for the chi-squared hypothesis test.
+ """
+ @property
+ def method(self):
+ """
+ Name of the test method
+ """
+ return self._java_model.method()
+
+ @property
+ def pValue(self):
+ """
+ The probability of obtaining a test statistic result at least as
+ extreme as the one that was actually observed, assuming that the
+ null hypothesis is true.
+ """
+ return self._java_model.pValue()
+
+ @property
+ def degreesOfFreedom(self):
+ """
+ Returns the degree(s) of freedom of the hypothesis test.
+ Return type should be Number (e.g. Int, Double) or tuples of Numbers.
+ """
+ return self._java_model.degreesOfFreedom()
+
+ @property
+ def statistic(self):
+ """
+ Test statistic.
+ """
+ return self._java_model.statistic()
+
+ @property
+ def nullHypothesis(self):
+ """
+ Null hypothesis of the test.
+ """
+ return self._java_model.nullHypothesis()
+
+ def __str__(self):
+ return self._java_model.toString()
+
+
class Statistics(object):
@staticmethod
@@ -58,6 +108,11 @@ def colStats(rdd):
"""
Computes column-wise summary statistics for the input RDD[Vector].
+ :param rdd: an RDD[Vector] for which column-wise summary statistics
+ are to be computed.
+ :return: :class:`MultivariateStatisticalSummary` object containing
+ column-wise summary statistics.
+
>>> from pyspark.mllib.linalg import Vectors
>>> rdd = sc.parallelize([Vectors.dense([2, 0, 0, -2]),
... Vectors.dense([4, 5, 0, 3]),
@@ -91,6 +146,13 @@ def corr(x, y=None, method=None):
to specify the method to be used for single RDD inout.
If two RDDs of floats are passed in, a single float is returned.
+ :param x: an RDD of vector for which the correlation matrix is to be computed,
+ or an RDD of float of the same cardinality as y when y is specified.
+ :param y: an RDD of float of the same cardinality as x.
+ :param method: String specifying the method to use for computing correlation.
+ Supported: `pearson` (default), `spearman`
+ :return: Correlation matrix comparing columns in x.
+
>>> x = sc.parallelize([1.0, 0.0, -2.0], 2)
>>> y = sc.parallelize([4.0, 5.0, 3.0], 2)
>>> zeros = sc.parallelize([0.0, 0.0, 0.0], 2)
@@ -135,6 +197,91 @@ def corr(x, y=None, method=None):
else:
return callMLlibFunc("corr", x.map(float), y.map(float), method)
+ @staticmethod
+ def chiSqTest(observed, expected=None):
+ """
+ :: Experimental ::
+
+        If `observed` is a Vector, conduct Pearson's chi-squared goodness
+        of fit test of the observed data against the expected distribution,
+        or against the uniform distribution (by default), with each category
+        having an expected frequency of `1 / len(observed)`.
+        (Note: `observed` cannot contain negative values.)
+
+        If `observed` is a matrix, conduct Pearson's independence test on the
+ input contingency matrix, which cannot contain negative entries or
+ columns or rows that sum up to 0.
+
+ If `observed` is an RDD of LabeledPoint, conduct Pearson's independence
+ test for every feature against the label across the input RDD.
+ For each feature, the (feature, label) pairs are converted into a
+ contingency matrix for which the chi-squared statistic is computed.
+ All label and feature values must be categorical.
+
+ :param observed: it could be a vector containing the observed categorical
+ counts/relative frequencies, or the contingency matrix
+ (containing either counts or relative frequencies),
+ or an RDD of LabeledPoint containing the labeled dataset
+ with categorical features. Real-valued features will be
+ treated as categorical for each distinct value.
+ :param expected: Vector containing the expected categorical counts/relative
+ frequencies. `expected` is rescaled if the `expected` sum
+ differs from the `observed` sum.
+        :return: ChiSqTestResult object containing the test statistic, degrees
+                 of freedom, p-value, the method used, and the null hypothesis.
+
+ >>> from pyspark.mllib.linalg import Vectors, Matrices
+ >>> observed = Vectors.dense([4, 6, 5])
+ >>> pearson = Statistics.chiSqTest(observed)
+ >>> print pearson.statistic
+ 0.4
+ >>> pearson.degreesOfFreedom
+ 2
+ >>> print round(pearson.pValue, 4)
+ 0.8187
+ >>> pearson.method
+ u'pearson'
+ >>> pearson.nullHypothesis
+ u'observed follows the same distribution as expected.'
+
+ >>> observed = Vectors.dense([21, 38, 43, 80])
+ >>> expected = Vectors.dense([3, 5, 7, 20])
+ >>> pearson = Statistics.chiSqTest(observed, expected)
+ >>> print round(pearson.pValue, 4)
+ 0.0027
+
+ >>> data = [40.0, 24.0, 29.0, 56.0, 32.0, 42.0, 31.0, 10.0, 0.0, 30.0, 15.0, 12.0]
+ >>> chi = Statistics.chiSqTest(Matrices.dense(3, 4, data))
+ >>> print round(chi.statistic, 4)
+ 21.9958
+
+ >>> data = [LabeledPoint(0.0, Vectors.dense([0.5, 10.0])),
+ ... LabeledPoint(0.0, Vectors.dense([1.5, 20.0])),
+ ... LabeledPoint(1.0, Vectors.dense([1.5, 30.0])),
+ ... LabeledPoint(0.0, Vectors.dense([3.5, 30.0])),
+ ... LabeledPoint(0.0, Vectors.dense([3.5, 40.0])),
+ ... LabeledPoint(1.0, Vectors.dense([3.5, 40.0])),]
+ >>> rdd = sc.parallelize(data, 4)
+ >>> chi = Statistics.chiSqTest(rdd)
+ >>> print chi[0].statistic
+ 0.75
+ >>> print chi[1].statistic
+ 1.5
+ """
+ if isinstance(observed, RDD):
+ if not isinstance(observed.first(), LabeledPoint):
+ raise ValueError("observed should be an RDD of LabeledPoint")
+ jmodels = callMLlibFunc("chiSqTest", observed)
+ return [ChiSqTestResult(m) for m in jmodels]
+
+ if isinstance(observed, Matrix):
+ jmodel = callMLlibFunc("chiSqTest", observed)
+ else:
+ if expected and len(expected) != len(observed):
+ raise ValueError("`expected` should have same length with `observed`")
+ jmodel = callMLlibFunc("chiSqTest", _convert_to_vector(observed), expected)
+ return ChiSqTestResult(jmodel)
+
def _test():
import doctest
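
The 0.4 statistic in the goodness-of-fit doctest above can be checked by hand: with the default uniform expectation each of the 3 categories expects 15/3 = 5 counts, and Pearson's statistic is sum((o - e)**2 / e). A minimal pure-Python sketch (the helper name is ours; no Spark needed):

    def pearson_chi_sq(observed, expected=None):
        # Pearson's statistic: sum((o - e)**2 / e); uniform expectation by default
        if expected is None:
            total = float(sum(observed))
            expected = [total / len(observed)] * len(observed)
        return sum((o - e) ** 2 / e for o, e in zip(observed, expected))

    print pearson_chi_sq([4, 6, 5])     # 0.4, matching the doctest above
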
diff --git a/python/pyspark/mllib/tests.py b/python/pyspark/mllib/tests.py
index d6fb87b378b4a..bc2ee5af496cf 100644
--- a/python/pyspark/mllib/tests.py
+++ b/python/pyspark/mllib/tests.py
@@ -33,14 +33,15 @@
else:
import unittest
-from pyspark.serializers import PickleSerializer
-from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, _convert_to_vector
+from pyspark.mllib.linalg import Vector, SparseVector, DenseVector, VectorUDT, _convert_to_vector,\
+ DenseMatrix
from pyspark.mllib.regression import LabeledPoint
from pyspark.mllib.random import RandomRDDs
from pyspark.mllib.stat import Statistics
+from pyspark.serializers import PickleSerializer
+from pyspark.sql import SQLContext
from pyspark.tests import ReusedPySparkTestCase as PySparkTestCase
-
_have_scipy = False
try:
import scipy.sparse
@@ -62,6 +63,7 @@ def _squared_distance(a, b):
class VectorTests(PySparkTestCase):
def _test_serialize(self, v):
+ self.assertEqual(v, ser.loads(ser.dumps(v)))
jvec = self.sc._jvm.SerDe.loads(bytearray(ser.dumps(v)))
nv = ser.loads(str(self.sc._jvm.SerDe.dumps(jvec)))
self.assertEqual(v, nv)
@@ -75,6 +77,8 @@ def test_serialize(self):
self._test_serialize(DenseVector(array([1., 2., 3., 4.])))
self._test_serialize(DenseVector(pyarray.array('d', range(10))))
self._test_serialize(SparseVector(4, {1: 1, 3: 2}))
+ self._test_serialize(SparseVector(3, {}))
+ self._test_serialize(DenseMatrix(2, 3, range(6)))
def test_dot(self):
sv = SparseVector(4, {1: 1, 3: 2})
@@ -105,6 +109,16 @@ def test_squared_distance(self):
self.assertEquals(0.0, _squared_distance(dv, dv))
self.assertEquals(0.0, _squared_distance(lst, lst))
+ def test_conversion(self):
+ # numpy arrays should be automatically upcast to float64
+ # tests for fix of [SPARK-5089]
+ v = array([1, 2, 3, 4], dtype='float64')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+ v = array([1, 2, 3, 4], dtype='float32')
+ dv = DenseVector(v)
+ self.assertTrue(dv.array.dtype == 'float64')
+
class ListTests(PySparkTestCase):
@@ -221,6 +235,39 @@ def test_col_with_different_rdds(self):
self.assertEqual(10, summary.count())
+class VectorUDTTests(PySparkTestCase):
+
+ dv0 = DenseVector([])
+ dv1 = DenseVector([1.0, 2.0])
+ sv0 = SparseVector(2, [], [])
+ sv1 = SparseVector(2, [1], [2.0])
+ udt = VectorUDT()
+
+ def test_json_schema(self):
+ self.assertEqual(VectorUDT.fromJson(self.udt.jsonValue()), self.udt)
+
+ def test_serialization(self):
+ for v in [self.dv0, self.dv1, self.sv0, self.sv1]:
+ self.assertEqual(v, self.udt.deserialize(self.udt.serialize(v)))
+
+ def test_infer_schema(self):
+ sqlCtx = SQLContext(self.sc)
+ rdd = self.sc.parallelize([LabeledPoint(1.0, self.dv1), LabeledPoint(0.0, self.sv1)])
+ srdd = sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "features"][0]
+ self.assertEqual(field.dataType, self.udt)
+ vectors = srdd.map(lambda p: p.features).collect()
+ self.assertEqual(len(vectors), 2)
+ for v in vectors:
+ if isinstance(v, SparseVector):
+ self.assertEqual(v, self.sv1)
+ elif isinstance(v, DenseVector):
+ self.assertEqual(v, self.dv1)
+ else:
+ raise ValueError("expecting a vector but got %r of type %r" % (v, type(v)))
+
+
@unittest.skipIf(not _have_scipy, "SciPy not installed")
class SciPyTests(PySparkTestCase):
diff --git a/python/pyspark/mllib/tree.py b/python/pyspark/mllib/tree.py
index 5d1a3c0962796..66702478474dc 100644
--- a/python/pyspark/mllib/tree.py
+++ b/python/pyspark/mllib/tree.py
@@ -15,12 +15,16 @@
# limitations under the License.
#
+from __future__ import absolute_import
+
+import random
+
from pyspark import SparkContext, RDD
from pyspark.mllib.common import callMLlibFunc, JavaModelWrapper
from pyspark.mllib.linalg import _convert_to_vector
from pyspark.mllib.regression import LabeledPoint
-__all__ = ['DecisionTreeModel', 'DecisionTree']
+__all__ = ['DecisionTreeModel', 'DecisionTree', 'RandomForestModel', 'RandomForest']
class DecisionTreeModel(JavaModelWrapper):
@@ -51,27 +55,25 @@ def depth(self):
return self._java_model.depth()
def __repr__(self):
- """ Print summary of model. """
+ """ summary of model. """
return self._java_model.toString()
def toDebugString(self):
- """ Print full model. """
+ """ full model. """
return self._java_model.toDebugString()
class DecisionTree(object):
"""
- Learning algorithm for a decision tree model
- for classification or regression.
+ Learning algorithm for a decision tree model for classification or regression.
EXPERIMENTAL: This is an experimental API.
- It will probably be modified for Spark v1.2.
-
+    It will probably be modified in the future.
"""
- @staticmethod
- def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32,
+ @classmethod
+ def _train(cls, data, type, numClasses, features, impurity="gini", maxDepth=5, maxBins=32,
minInstancesPerNode=1, minInfoGain=0.0):
first = data.first()
assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
@@ -79,8 +81,8 @@ def _train(data, type, numClasses, features, impurity="gini", maxDepth=5, maxBin
impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
return DecisionTreeModel(model)
- @staticmethod
- def trainClassifier(data, numClasses, categoricalFeaturesInfo,
+ @classmethod
+ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo,
impurity="gini", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
@@ -98,8 +100,8 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
- :param minInstancesPerNode: Min number of instances required at child nodes to create
- the parent split
+ :param minInstancesPerNode: Min number of instances required at child
+ nodes to create the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
@@ -124,16 +126,19 @@ def trainClassifier(data, numClasses, categoricalFeaturesInfo,
Predict: 0.0
Else (feature 0 > 0.0)
Predict: 1.0
- >>> model.predict(array([1.0])) > 0
- True
- >>> model.predict(array([0.0])) == 0
- True
+ >>> model.predict(array([1.0]))
+ 1.0
+ >>> model.predict(array([0.0]))
+ 0.0
+ >>> rdd = sc.parallelize([[1.0], [0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
"""
- return DecisionTree._train(data, "classification", numClasses, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ return cls._train(data, "classification", numClasses, categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
- @staticmethod
- def trainRegressor(data, categoricalFeaturesInfo,
+ @classmethod
+ def trainRegressor(cls, data, categoricalFeaturesInfo,
impurity="variance", maxDepth=5, maxBins=32, minInstancesPerNode=1,
minInfoGain=0.0):
"""
@@ -150,14 +155,13 @@ def trainRegressor(data, categoricalFeaturesInfo,
E.g., depth 0 means 1 leaf node.
Depth 1 means 1 internal node + 2 leaf nodes.
:param maxBins: Number of bins used for finding splits at each node.
- :param minInstancesPerNode: Min number of instances required at child nodes to create
- the parent split
+ :param minInstancesPerNode: Min number of instances required at child
+ nodes to create the parent split
:param minInfoGain: Min info gain required to create a split
:return: DecisionTreeModel
Example usage:
- >>> from numpy import array
>>> from pyspark.mllib.regression import LabeledPoint
>>> from pyspark.mllib.tree import DecisionTree
>>> from pyspark.mllib.linalg import SparseVector
@@ -170,17 +174,213 @@ def trainRegressor(data, categoricalFeaturesInfo,
... ]
>>>
>>> model = DecisionTree.trainRegressor(sc.parallelize(sparse_data), {})
- >>> model.predict(array([0.0, 1.0])) == 1
- True
- >>> model.predict(array([0.0, 0.0])) == 0
- True
- >>> model.predict(SparseVector(2, {1: 1.0})) == 1
- True
- >>> model.predict(SparseVector(2, {1: 0.0})) == 0
- True
- """
- return DecisionTree._train(data, "regression", 0, categoricalFeaturesInfo,
- impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {1: 0.0}))
+ 0.0
+ >>> rdd = sc.parallelize([[0.0, 1.0], [0.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
+ """
+ return cls._train(data, "regression", 0, categoricalFeaturesInfo,
+ impurity, maxDepth, maxBins, minInstancesPerNode, minInfoGain)
+
+
+class RandomForestModel(JavaModelWrapper):
+ """
+ Represents a random forest model.
+
+ EXPERIMENTAL: This is an experimental API.
+    It will probably be modified in the future.
+ """
+ def predict(self, x):
+ """
+ Predict values for a single data point or an RDD of points using
+ the model trained.
+ """
+ if isinstance(x, RDD):
+ return self.call("predict", x.map(_convert_to_vector))
+
+ else:
+ return self.call("predict", _convert_to_vector(x))
+
+ def numTrees(self):
+ """
+ Get number of trees in forest.
+ """
+ return self.call("numTrees")
+
+ def totalNumNodes(self):
+ """
+ Get total number of nodes, summed over all trees in the forest.
+ """
+ return self.call("totalNumNodes")
+
+ def __repr__(self):
+ """ Summary of model """
+ return self._java_model.toString()
+
+ def toDebugString(self):
+ """ Full model """
+ return self._java_model.toDebugString()
+
+
+class RandomForest(object):
+ """
+ Learning algorithm for a random forest model for classification or regression.
+
+ EXPERIMENTAL: This is an experimental API.
+ It will probably be modified in future.
+ """
+
+ supportedFeatureSubsetStrategies = ("auto", "all", "sqrt", "log2", "onethird")
+
+ @classmethod
+ def _train(cls, data, algo, numClasses, categoricalFeaturesInfo, numTrees,
+ featureSubsetStrategy, impurity, maxDepth, maxBins, seed):
+ first = data.first()
+ assert isinstance(first, LabeledPoint), "the data should be RDD of LabeledPoint"
+ if featureSubsetStrategy not in cls.supportedFeatureSubsetStrategies:
+ raise ValueError("unsupported featureSubsetStrategy: %s" % featureSubsetStrategy)
+ if seed is None:
+ seed = random.randint(0, 1 << 30)
+ model = callMLlibFunc("trainRandomForestModel", data, algo, numClasses,
+ categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
+ maxDepth, maxBins, seed)
+ return RandomForestModel(model)
+
+ @classmethod
+ def trainClassifier(cls, data, numClasses, categoricalFeaturesInfo, numTrees,
+ featureSubsetStrategy="auto", impurity="gini", maxDepth=4, maxBins=32,
+ seed=None):
+ """
+        Method to train a random forest model for binary or multiclass
+        classification.
+
+ :param data: Training dataset: RDD of LabeledPoint. Labels should take
+ values {0, 1, ..., numClasses-1}.
+ :param numClasses: number of classes for classification.
+ :param categoricalFeaturesInfo: Map storing arity of categorical features.
+ E.g., an entry (n -> k) indicates that feature n is categorical
+ with k categories indexed from 0: {0, 1, ..., k-1}.
+ :param numTrees: Number of trees in the random forest.
+ :param featureSubsetStrategy: Number of features to consider for splits at
+ each node.
+ Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ If "auto" is set, this parameter is set based on numTrees:
+ if numTrees == 1, set to "all";
+ if numTrees > 1 (forest) set to "sqrt".
+ :param impurity: Criterion used for information gain calculation.
+ Supported values: "gini" (recommended) or "entropy".
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1 leaf node;
+ depth 1 means 1 internal node + 2 leaf nodes. (default: 4)
+        :param maxBins: Maximum number of bins used for splitting features.
+          (default: 32)
+ :param seed: Random seed for bootstrapping and choosing feature subsets.
+ :return: RandomForestModel that can be used for prediction
+
+ Example usage:
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import RandomForest
+ >>>
+ >>> data = [
+ ... LabeledPoint(0.0, [0.0]),
+ ... LabeledPoint(0.0, [1.0]),
+ ... LabeledPoint(1.0, [2.0]),
+ ... LabeledPoint(1.0, [3.0])
+ ... ]
+ >>> model = RandomForest.trainClassifier(sc.parallelize(data), 2, {}, 3, seed=42)
+ >>> model.numTrees()
+ 3
+ >>> model.totalNumNodes()
+ 7
+ >>> print model,
+ TreeEnsembleModel classifier with 3 trees
+ >>> print model.toDebugString(),
+ TreeEnsembleModel classifier with 3 trees
+
+ Tree 0:
+ Predict: 1.0
+ Tree 1:
+ If (feature 0 <= 1.0)
+ Predict: 0.0
+ Else (feature 0 > 1.0)
+ Predict: 1.0
+ Tree 2:
+ If (feature 0 <= 1.0)
+ Predict: 0.0
+ Else (feature 0 > 1.0)
+ Predict: 1.0
+ >>> model.predict([2.0])
+ 1.0
+ >>> model.predict([0.0])
+ 0.0
+ >>> rdd = sc.parallelize([[3.0], [1.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.0]
+ """
+ return cls._train(data, "classification", numClasses,
+ categoricalFeaturesInfo, numTrees, featureSubsetStrategy, impurity,
+ maxDepth, maxBins, seed)
+
+ @classmethod
+ def trainRegressor(cls, data, categoricalFeaturesInfo, numTrees, featureSubsetStrategy="auto",
+ impurity="variance", maxDepth=4, maxBins=32, seed=None):
+ """
+        Method to train a random forest model for regression.
+
+ :param data: Training dataset: RDD of LabeledPoint. Labels are
+ real numbers.
+ :param categoricalFeaturesInfo: Map storing arity of categorical
+ features. E.g., an entry (n -> k) indicates that feature
+ n is categorical with k categories indexed from 0:
+ {0, 1, ..., k-1}.
+ :param numTrees: Number of trees in the random forest.
+ :param featureSubsetStrategy: Number of features to consider for
+ splits at each node.
+ Supported: "auto" (default), "all", "sqrt", "log2", "onethird".
+ If "auto" is set, this parameter is set based on numTrees:
+ if numTrees == 1, set to "all";
+ if numTrees > 1 (forest) set to "onethird" for regression.
+ :param impurity: Criterion used for information gain calculation.
+ Supported values: "variance".
+ :param maxDepth: Maximum depth of the tree. E.g., depth 0 means 1
+ leaf node; depth 1 means 1 internal node + 2 leaf nodes.
+ (default: 4)
+        :param maxBins: Maximum number of bins used for splitting features.
+          (default: 32)
+ :param seed: Random seed for bootstrapping and choosing feature subsets.
+ :return: RandomForestModel that can be used for prediction
+
+ Example usage:
+
+ >>> from pyspark.mllib.regression import LabeledPoint
+ >>> from pyspark.mllib.tree import RandomForest
+ >>> from pyspark.mllib.linalg import SparseVector
+ >>>
+ >>> sparse_data = [
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 1.0})),
+ ... LabeledPoint(0.0, SparseVector(2, {0: 1.0})),
+ ... LabeledPoint(1.0, SparseVector(2, {1: 2.0}))
+ ... ]
+ >>>
+ >>> model = RandomForest.trainRegressor(sc.parallelize(sparse_data), {}, 2, seed=42)
+ >>> model.numTrees()
+ 2
+ >>> model.totalNumNodes()
+ 4
+ >>> model.predict(SparseVector(2, {1: 1.0}))
+ 1.0
+ >>> model.predict(SparseVector(2, {0: 1.0}))
+ 0.5
+ >>> rdd = sc.parallelize([[0.0, 1.0], [1.0, 0.0]])
+ >>> model.predict(rdd).collect()
+ [1.0, 0.5]
+ """
+ return cls._train(data, "regression", 0, categoricalFeaturesInfo, numTrees,
+ featureSubsetStrategy, impurity, maxDepth, maxBins, seed)
def _test():
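
The `featureSubsetStrategy` docstrings above only describe how "auto" is resolved; the actual per-node feature counts are computed on the Scala side. As a rough, purely illustrative sketch of the mapping (the rounding here is an assumption, not the exact MLlib formula):

    import math

    def features_per_node(num_features, num_trees, algo, strategy="auto"):
        # resolve "auto" exactly as documented above
        if strategy == "auto":
            if num_trees == 1:
                strategy = "all"
            else:
                strategy = "sqrt" if algo == "classification" else "onethird"
        return {
            "all": num_features,
            "sqrt": int(math.ceil(math.sqrt(num_features))),     # assumed rounding
            "log2": int(math.ceil(math.log(num_features, 2))),   # assumed rounding
            "onethird": int(math.ceil(num_features / 3.0)),      # assumed rounding
        }[strategy]

    print features_per_node(100, 3, "classification")   # 10
    print features_per_node(100, 3, "regression")       # 34
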
diff --git a/python/pyspark/mllib/util.py b/python/pyspark/mllib/util.py
index 96aef8f510fa6..4ed978b45409c 100644
--- a/python/pyspark/mllib/util.py
+++ b/python/pyspark/mllib/util.py
@@ -161,15 +161,8 @@ def loadLabeledPoints(sc, path, minPartitions=None):
>>> tempFile = NamedTemporaryFile(delete=True)
>>> tempFile.close()
>>> sc.parallelize(examples, 1).saveAsTextFile(tempFile.name)
- >>> loaded = MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
- >>> type(loaded[0]) == LabeledPoint
- True
- >>> print examples[0]
- (1.1,(3,[0,2],[-1.23,4.56e-07]))
- >>> type(examples[1]) == LabeledPoint
- True
- >>> print examples[1]
- (0.0,[1.01,2.02,3.03])
+ >>> MLUtils.loadLabeledPoints(sc, tempFile.name).collect()
+ [LabeledPoint(1.1, (3,[0,2],[-1.23,4.56e-07])), LabeledPoint(0.0, [1.01,2.02,3.03])]
"""
minPartitions = minPartitions or min(sc.defaultParallelism, 2)
return callMLlibFunc("loadLabeledPoints", sc, path, minPartitions)
diff --git a/python/pyspark/rdd.py b/python/pyspark/rdd.py
index 550c9dd80522f..f8b5f18253328 100644
--- a/python/pyspark/rdd.py
+++ b/python/pyspark/rdd.py
@@ -28,7 +28,7 @@
import warnings
import heapq
import bisect
-from random import Random
+import random
from math import sqrt, log, isinf, isnan
from pyspark.accumulators import PStatsParam
@@ -38,7 +38,7 @@
from pyspark.join import python_join, python_left_outer_join, \
python_right_outer_join, python_full_outer_join, python_cogroup
from pyspark.statcounter import StatCounter
-from pyspark.rddsampler import RDDSampler, RDDStratifiedSampler
+from pyspark.rddsampler import RDDSampler, RDDRangeSampler, RDDStratifiedSampler
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
@@ -120,7 +120,7 @@ class RDD(object):
operated on in parallel.
"""
- def __init__(self, jrdd, ctx, jrdd_deserializer):
+ def __init__(self, jrdd, ctx, jrdd_deserializer=AutoBatchedSerializer(PickleSerializer())):
self._jrdd = jrdd
self.is_cached = False
self.is_checkpointed = False
@@ -129,12 +129,8 @@ def __init__(self, jrdd, ctx, jrdd_deserializer):
self._id = jrdd.id()
self._partitionFunc = None
- def _toPickleSerialization(self):
- if (self._jrdd_deserializer == PickleSerializer() or
- self._jrdd_deserializer == BatchedSerializer(PickleSerializer())):
- return self
- else:
- return self._reserialize(BatchedSerializer(PickleSerializer(), 10))
+ def _pickled(self):
+ return self._reserialize(AutoBatchedSerializer(PickleSerializer()))
def id(self):
"""
@@ -145,6 +141,17 @@ def id(self):
def __repr__(self):
return self._jrdd.toString()
+ def __getnewargs__(self):
+ # This method is called when attempting to pickle an RDD, which is always an error:
+ raise Exception(
+ "It appears that you are attempting to broadcast an RDD or reference an RDD from an "
+ "action or transformation. RDD transformations and actions can only be invoked by the "
+ "driver, not inside of other transformations; for example, "
+ "rdd1.map(lambda x: rdd2.values.count() * x) is invalid because the values "
+ "transformation and count action cannot be performed inside of the rdd1.map "
+ "transformation. For more information, see SPARK-5063."
+ )
+
@property
def context(self):
"""
@@ -314,20 +321,43 @@ def distinct(self, numPartitions=None):
def sample(self, withReplacement, fraction, seed=None):
"""
- Return a sampled subset of this RDD (relies on numpy and falls back
- on default random generator if numpy is unavailable).
+ Return a sampled subset of this RDD.
- >>> sc.parallelize(range(0, 100)).sample(False, 0.1, 2).collect() #doctest: +SKIP
- [2, 3, 20, 21, 24, 41, 42, 66, 67, 89, 90, 98]
+ >>> rdd = sc.parallelize(range(100), 4)
+ >>> rdd.sample(False, 0.1, 81).count()
+ 10
"""
assert fraction >= 0.0, "Negative fraction value: %s" % fraction
return self.mapPartitionsWithIndex(RDDSampler(withReplacement, fraction, seed).func, True)
+ def randomSplit(self, weights, seed=None):
+ """
+ Randomly splits this RDD with the provided weights.
+
+ :param weights: weights for splits, will be normalized if they don't sum to 1
+ :param seed: random seed
+ :return: split RDDs in a list
+
+ >>> rdd = sc.parallelize(range(5), 1)
+ >>> rdd1, rdd2 = rdd.randomSplit([2, 3], 17)
+ >>> rdd1.collect()
+ [1, 3]
+ >>> rdd2.collect()
+ [0, 2, 4]
+ """
+ s = float(sum(weights))
+ cweights = [0.0]
+ for w in weights:
+ cweights.append(cweights[-1] + w / s)
+ if seed is None:
+ seed = random.randint(0, 2 ** 32 - 1)
+ return [self.mapPartitionsWithIndex(RDDRangeSampler(lb, ub, seed).func, True)
+ for lb, ub in zip(cweights, cweights[1:])]
+
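
`randomSplit` turns the normalized weights into adjacent [lower, upper) ranges and gives each range to an `RDDRangeSampler`, so every row is assigned to exactly one split by a single uniform draw. The boundary arithmetic on its own:

    weights = [2, 3]
    s = float(sum(weights))                  # 5.0
    cweights = [0.0]
    for w in weights:
        cweights.append(cweights[-1] + w / s)
    print zip(cweights, cweights[1:])        # [(0.0, 0.4), (0.4, 1.0)] (approximately)
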
# this is ported from scala/spark/RDD.scala
def takeSample(self, withReplacement, num, seed=None):
"""
- Return a fixed-size sampled subset of this RDD (currently requires
- numpy).
+ Return a fixed-size sampled subset of this RDD.
>>> rdd = sc.parallelize(range(0, 10))
>>> len(rdd.takeSample(True, 20, 1))
@@ -348,7 +378,7 @@ def takeSample(self, withReplacement, num, seed=None):
if initialCount == 0:
return []
- rand = Random(seed)
+ rand = random.Random(seed)
if (not withReplacement) and num >= initialCount:
# shuffle current RDD and return
@@ -449,12 +479,10 @@ def intersection(self, other):
def _reserialize(self, serializer=None):
serializer = serializer or self.ctx.serializer
- if self._jrdd_deserializer == serializer:
- return self
- else:
- converted = self.map(lambda x: x, preservesPartitioning=True)
- converted._jrdd_deserializer = serializer
- return converted
+ if self._jrdd_deserializer != serializer:
+ self = self.map(lambda x: x, preservesPartitioning=True)
+ self._jrdd_deserializer = serializer
+ return self
def __add__(self, other):
"""
@@ -529,6 +557,8 @@ def sortPartition(iterator):
# the key-space into bins such that the bins have roughly the same
# number of (key, value) pairs falling into them
rddSize = self.count()
+ if not rddSize:
+ return self # empty RDD
maxSampleSize = numPartitions * 20.0 # constant from Spark's RangePartitioner
fraction = min(maxSampleSize / max(rddSize, 1), 1.0)
samples = self.sample(False, fraction, 1).map(lambda (k, v): k).collect()
@@ -1123,9 +1153,8 @@ def saveAsNewAPIHadoopDataset(self, conf, keyConverter=None, valueConverter=None
:param valueConverter: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf,
keyConverter, valueConverter, True)
def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
@@ -1150,9 +1179,8 @@ def saveAsNewAPIHadoopFile(self, path, outputFormatClass, keyClass=None, valueCl
:param conf: Hadoop job configuration, passed in as a dict (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, batched, path,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsNewAPIHadoopFile(pickledRDD._jrdd, True, path,
outputFormatClass,
keyClass, valueClass,
keyConverter, valueConverter, jconf)
@@ -1169,9 +1197,8 @@ def saveAsHadoopDataset(self, conf, keyConverter=None, valueConverter=None):
:param valueConverter: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, batched, jconf,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsHadoopDataset(pickledRDD._jrdd, True, jconf,
keyConverter, valueConverter, False)
def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=None,
@@ -1198,9 +1225,8 @@ def saveAsHadoopFile(self, path, outputFormatClass, keyClass=None, valueClass=No
:param compressionCodecClass: (None by default)
"""
jconf = self.ctx._dictToJavaMap(conf)
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, batched, path,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsHadoopFile(pickledRDD._jrdd, True, path,
outputFormatClass,
keyClass, valueClass,
keyConverter, valueConverter,
@@ -1218,9 +1244,8 @@ def saveAsSequenceFile(self, path, compressionCodecClass=None):
:param path: path to sequence file
:param compressionCodecClass: (None by default)
"""
- pickledRDD = self._toPickleSerialization()
- batched = isinstance(pickledRDD._jrdd_deserializer, BatchedSerializer)
- self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, batched,
+ pickledRDD = self._pickled()
+ self.ctx._jvm.PythonRDD.saveAsSequenceFile(pickledRDD._jrdd, True,
path, compressionCodecClass)
def saveAsPickleFile(self, path, batchSize=10):
@@ -1235,8 +1260,11 @@ def saveAsPickleFile(self, path, batchSize=10):
>>> sorted(sc.pickleFile(tmpFile.name, 5).collect())
[1, 2, 'rdd', 'spark']
"""
- self._reserialize(BatchedSerializer(PickleSerializer(),
- batchSize))._jrdd.saveAsObjectFile(path)
+ if batchSize == 0:
+ ser = AutoBatchedSerializer(PickleSerializer())
+ else:
+ ser = BatchedSerializer(PickleSerializer(), batchSize)
+ self._reserialize(ser)._jrdd.saveAsObjectFile(path)
def saveAsTextFile(self, path):
"""
@@ -1777,28 +1805,27 @@ def zip(self, other):
>>> x.zip(y).collect()
[(0, 1000), (1, 1001), (2, 1002), (3, 1003), (4, 1004)]
"""
- if self.getNumPartitions() != other.getNumPartitions():
- raise ValueError("Can only zip with RDD which has the same number of partitions")
-
def get_batch_size(ser):
if isinstance(ser, BatchedSerializer):
return ser.batchSize
- return 0
+ return 1 # not batched
def batch_as(rdd, batchSize):
- ser = rdd._jrdd_deserializer
- if isinstance(ser, BatchedSerializer):
- ser = ser.serializer
- return rdd._reserialize(BatchedSerializer(ser, batchSize))
+ return rdd._reserialize(BatchedSerializer(PickleSerializer(), batchSize))
my_batch = get_batch_size(self._jrdd_deserializer)
other_batch = get_batch_size(other._jrdd_deserializer)
if my_batch != other_batch:
- # use the greatest batchSize to batch the other one.
- if my_batch > other_batch:
- other = batch_as(other, my_batch)
- else:
- self = batch_as(self, other_batch)
+ # use the smallest batchSize for both of them
+ batchSize = min(my_batch, other_batch)
+ if batchSize <= 0:
+ # auto batched or unlimited
+ batchSize = 100
+ other = batch_as(other, batchSize)
+ self = batch_as(self, batchSize)
+
+ if self.getNumPartitions() != other.getNumPartitions():
+ raise ValueError("Can only zip with RDD which has the same number of partitions")
# There will be an Exception in JVM if there are different number
# of items in each partitions.
@@ -1937,25 +1964,14 @@ def lookup(self, key):
return values.collect()
- def _is_pickled(self):
- """ Return this RDD is serialized by Pickle or not. """
- der = self._jrdd_deserializer
- if isinstance(der, PickleSerializer):
- return True
- if isinstance(der, BatchedSerializer) and isinstance(der.serializer, PickleSerializer):
- return True
- return False
-
def _to_java_object_rdd(self):
""" Return an JavaRDD of Object by unpickling
It will convert each Python object into Java object by Pyrolite, whenever the
RDD is serialized in batch or not.
"""
- rdd = self._reserialize(AutoBatchedSerializer(PickleSerializer())) \
- if not self._is_pickled() else self
- is_batch = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
- return self.ctx._jvm.PythonRDD.pythonToJava(rdd._jrdd, is_batch)
+ rdd = self._pickled()
+ return self.ctx._jvm.SerDeUtil.pythonToJava(rdd._jrdd, True)
def countApprox(self, timeout, confidence=0.95):
"""
@@ -2135,7 +2151,7 @@ def _test():
globs = globals().copy()
# The small batch size here ensures that we see multiple batches,
# even in these small test examples:
- globs['sc'] = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ globs['sc'] = SparkContext('local[4]', 'PythonTest')
(failure_count, test_count) = doctest.testmod(
globs=globs, optionflags=doctest.ELLIPSIS)
globs['sc'].stop()
diff --git a/python/pyspark/rddsampler.py b/python/pyspark/rddsampler.py
index 528a181e8905a..459e1427803cb 100644
--- a/python/pyspark/rddsampler.py
+++ b/python/pyspark/rddsampler.py
@@ -17,82 +17,48 @@
import sys
import random
+import math
class RDDSamplerBase(object):
def __init__(self, withReplacement, seed=None):
- try:
- import numpy
- self._use_numpy = True
- except ImportError:
- print >> sys.stderr, (
- "NumPy does not appear to be installed. "
- "Falling back to default random generator for sampling.")
- self._use_numpy = False
-
- self._seed = seed if seed is not None else random.randint(0, 2 ** 32 - 1)
+ self._seed = seed if seed is not None else random.randint(0, sys.maxint)
self._withReplacement = withReplacement
self._random = None
- self._split = None
- self._rand_initialized = False
def initRandomGenerator(self, split):
- if self._use_numpy:
- import numpy
- self._random = numpy.random.RandomState(self._seed)
+ self._random = random.Random(self._seed ^ split)
+
+ # mixing because the initial seeds are close to each other
+ for _ in xrange(10):
+ self._random.randint(0, 1)
+
+ def getUniformSample(self):
+ return self._random.random()
+
+ def getPoissonSample(self, mean):
+ # Using Knuth's algorithm described in
+ # http://en.wikipedia.org/wiki/Poisson_distribution
+ if mean < 20.0:
+ # one exp and k+1 random calls
+ l = math.exp(-mean)
+ p = self._random.random()
+ k = 0
+ while p > l:
+ k += 1
+ p *= self._random.random()
else:
- self._random = random.Random(self._seed)
+ # switch to the log domain, k+1 expovariate (random + log) calls
+ p = self._random.expovariate(mean)
+ k = 0
+ while p < 1.0:
+ k += 1
+ p += self._random.expovariate(mean)
+ return k
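
Both branches draw k ~ Poisson(mean): Knuth's method multiplies uniforms until the product drops below exp(-mean), while the log-domain branch sums exponential inter-arrival times instead, which avoids exp(-mean) underflowing for large means. A quick standalone check of the first branch (mirroring the code above, not part of the class):

    import math
    import random

    def poisson_knuth(mean, rng):
        l, p, k = math.exp(-mean), rng.random(), 0
        while p > l:
            k += 1
            p *= rng.random()
        return k

    rng = random.Random(42)
    samples = [poisson_knuth(0.5, rng) for _ in xrange(100000)]
    print sum(samples) / float(len(samples))     # close to the mean, 0.5
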
- for _ in range(0, split):
- # discard the next few values in the sequence to have a
- # different seed for the different splits
- self._random.randint(0, 2 ** 32 - 1)
-
- self._split = split
- self._rand_initialized = True
-
- def getUniformSample(self, split):
- if not self._rand_initialized or split != self._split:
- self.initRandomGenerator(split)
-
- if self._use_numpy:
- return self._random.random_sample()
- else:
- return self._random.uniform(0.0, 1.0)
-
- def getPoissonSample(self, split, mean):
- if not self._rand_initialized or split != self._split:
- self.initRandomGenerator(split)
-
- if self._use_numpy:
- return self._random.poisson(mean)
- else:
- # here we simulate drawing numbers n_i ~ Poisson(lambda = 1/mean) by
- # drawing a sequence of numbers delta_j ~ Exp(mean)
- num_arrivals = 1
- cur_time = 0.0
-
- cur_time += self._random.expovariate(mean)
-
- if cur_time > 1.0:
- return 0
-
- while(cur_time <= 1.0):
- cur_time += self._random.expovariate(mean)
- num_arrivals += 1
-
- return (num_arrivals - 1)
-
- def shuffle(self, vals):
- if self._random is None:
- self.initRandomGenerator(0) # this should only ever called on the master so
- # the split does not matter
-
- if self._use_numpy:
- self._random.shuffle(vals)
- else:
- self._random.shuffle(vals, self._random.random)
+ def func(self, split, iterator):
+ raise NotImplementedError
class RDDSampler(RDDSamplerBase):
@@ -102,20 +68,35 @@ def __init__(self, withReplacement, fraction, seed=None):
self._fraction = fraction
def func(self, split, iterator):
+ self.initRandomGenerator(split)
if self._withReplacement:
for obj in iterator:
# For large datasets, the expected number of occurrences of each element in
# a sample with replacement is Poisson(frac). We use that to get a count for
# each element.
- count = self.getPoissonSample(split, mean=self._fraction)
+ count = self.getPoissonSample(self._fraction)
for _ in range(0, count):
yield obj
else:
for obj in iterator:
- if self.getUniformSample(split) <= self._fraction:
+ if self.getUniformSample() < self._fraction:
yield obj
+class RDDRangeSampler(RDDSamplerBase):
+
+ def __init__(self, lowerBound, upperBound, seed=None):
+ RDDSamplerBase.__init__(self, False, seed)
+ self._lowerBound = lowerBound
+ self._upperBound = upperBound
+
+ def func(self, split, iterator):
+ self.initRandomGenerator(split)
+ for obj in iterator:
+ if self._lowerBound <= self.getUniformSample() < self._upperBound:
+ yield obj
+
+
class RDDStratifiedSampler(RDDSamplerBase):
def __init__(self, withReplacement, fractions, seed=None):
@@ -123,15 +104,16 @@ def __init__(self, withReplacement, fractions, seed=None):
self._fractions = fractions
def func(self, split, iterator):
+ self.initRandomGenerator(split)
if self._withReplacement:
for key, val in iterator:
# For large datasets, the expected number of occurrences of each element in
# a sample with replacement is Poisson(frac). We use that to get a count for
# each element.
- count = self.getPoissonSample(split, mean=self._fractions[key])
+ count = self.getPoissonSample(self._fractions[key])
for _ in range(0, count):
yield key, val
else:
for key, val in iterator:
- if self.getUniformSample(split) <= self._fractions[key]:
+ if self.getUniformSample() < self._fractions[key]:
yield key, val
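
These samplers are not used directly: `RDD.sample` and `RDD.randomSplit` wrap `func(split, iterator)` in `mapPartitionsWithIndex`, so every partition re-seeds its generator from `seed ^ split`. A sketch of the wiring, mirroring the `sample()` implementation shown earlier in this patch:

    from pyspark.rddsampler import RDDSampler

    rdd = sc.parallelize(range(100), 4)
    sampler = RDDSampler(False, 0.1, seed=81)
    sampled = rdd.mapPartitionsWithIndex(sampler.func, True)
    print sampled.count()     # 10, same as rdd.sample(False, 0.1, 81).count()
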
diff --git a/python/pyspark/serializers.py b/python/pyspark/serializers.py
index 904bd9f2652d3..b8bda835174b2 100644
--- a/python/pyspark/serializers.py
+++ b/python/pyspark/serializers.py
@@ -33,9 +33,8 @@
[0, 2, 4, 6, 8, 10, 12, 14, 16, 18]
>>> sc.stop()
-By default, PySpark serialize objects in batches; the batch size can be
-controlled through SparkContext's C{batchSize} parameter
-(the default size is 1024 objects):
+PySpark serializes objects in batches; by default, the batch size is chosen based
+on the size of the objects, and it is also configurable through SparkContext's C{batchSize} parameter:
>>> sc = SparkContext('local', 'test', batchSize=2)
>>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
@@ -48,16 +47,6 @@
>>> rdd._jrdd.count()
8L
>>> sc.stop()
-
-A batch size of -1 uses an unlimited batch size, and a size of 1 disables
-batching:
-
->>> sc = SparkContext('local', 'test', batchSize=1)
->>> rdd = sc.parallelize(range(16), 4).map(lambda x: x)
->>> rdd.glom().collect()
-[[0, 1, 2, 3], [4, 5, 6, 7], [8, 9, 10, 11], [12, 13, 14, 15]]
->>> rdd._jrdd.count()
-16L
"""
import cPickle
@@ -73,7 +62,7 @@
from pyspark import cloudpickle
-__all__ = ["PickleSerializer", "MarshalSerializer"]
+__all__ = ["PickleSerializer", "MarshalSerializer", "UTF8Deserializer"]
class SpecialLengths(object):
@@ -113,7 +102,7 @@ def __ne__(self, other):
return not self.__eq__(other)
def __repr__(self):
- return "<%s object>" % self.__class__.__name__
+ return "%s()" % self.__class__.__name__
def __hash__(self):
return hash(str(self))
@@ -144,6 +133,8 @@ def load_stream(self, stream):
def _write_with_length(self, obj, stream):
serialized = self.dumps(obj)
+ if len(serialized) > (1 << 31):
+ raise ValueError("can not serialize object larger than 2G")
write_int(len(serialized), stream)
if self._only_write_strings:
stream.write(str(serialized))
@@ -181,6 +172,7 @@ class BatchedSerializer(Serializer):
"""
UNLIMITED_BATCH_SIZE = -1
+ UNKNOWN_BATCH_SIZE = 0
def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
self.serializer = serializer
@@ -189,6 +181,10 @@ def __init__(self, serializer, batchSize=UNLIMITED_BATCH_SIZE):
def _batched(self, iterator):
if self.batchSize == self.UNLIMITED_BATCH_SIZE:
yield list(iterator)
+ elif hasattr(iterator, "__len__") and hasattr(iterator, "__getslice__"):
+ n = len(iterator)
+ for i in xrange(0, n, self.batchSize):
+ yield iterator[i: i + self.batchSize]
else:
items = []
count = 0
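
The new fast path slices the input whenever it supports `__len__`/`__getslice__` (a plain list in Python 2); anything else falls back to the accumulate-and-flush loop. An illustrative use of the private helper:

    from pyspark.serializers import BatchedSerializer, PickleSerializer

    ser = BatchedSerializer(PickleSerializer(), 4)
    print list(ser._batched(range(10)))              # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
    print list(ser._batched(x for x in range(10)))   # same batches, via the fallback loop
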
@@ -213,10 +209,10 @@ def _load_stream_without_unbatching(self, stream):
def __eq__(self, other):
return (isinstance(other, BatchedSerializer) and
- other.serializer == self.serializer)
+ other.serializer == self.serializer and other.batchSize == self.batchSize)
def __repr__(self):
- return "BatchedSerializer<%s>" % str(self.serializer)
+ return "BatchedSerializer(%s, %d)" % (str(self.serializer), self.batchSize)
class AutoBatchedSerializer(BatchedSerializer):
@@ -225,7 +221,7 @@ class AutoBatchedSerializer(BatchedSerializer):
"""
def __init__(self, serializer, bestSize=1 << 16):
- BatchedSerializer.__init__(self, serializer, -1)
+ BatchedSerializer.__init__(self, serializer, self.UNKNOWN_BATCH_SIZE)
self.bestSize = bestSize
def dump_stream(self, iterator, stream):
@@ -248,10 +244,10 @@ def dump_stream(self, iterator, stream):
def __eq__(self, other):
return (isinstance(other, AutoBatchedSerializer) and
- other.serializer == self.serializer)
+ other.serializer == self.serializer and other.bestSize == self.bestSize)
def __str__(self):
- return "AutoBatchedSerializer<%s>" % str(self.serializer)
+ return "AutoBatchedSerializer(%s)" % str(self.serializer)
class CartesianDeserializer(FramedSerializer):
@@ -284,7 +280,7 @@ def __eq__(self, other):
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __repr__(self):
- return "CartesianDeserializer<%s, %s>" % \
+ return "CartesianDeserializer(%s, %s)" % \
(str(self.key_ser), str(self.val_ser))
@@ -311,7 +307,7 @@ def __eq__(self, other):
self.key_ser == other.key_ser and self.val_ser == other.val_ser)
def __repr__(self):
- return "PairDeserializer<%s, %s>" % (str(self.key_ser), str(self.val_ser))
+ return "PairDeserializer(%s, %s)" % (str(self.key_ser), str(self.val_ser))
class NoOpSerializer(FramedSerializer):
@@ -430,7 +426,7 @@ def loads(self, obj):
class AutoSerializer(FramedSerializer):
"""
- Choose marshal or cPickle as serialization protocol autumatically
+ Choose marshal or cPickle as serialization protocol automatically
"""
def __init__(self):
@@ -460,9 +456,9 @@ class CompressedSerializer(FramedSerializer):
"""
Compress the serialized data
"""
-
def __init__(self, serializer):
FramedSerializer.__init__(self)
+ assert isinstance(serializer, FramedSerializer), "serializer must be a FramedSerializer"
self.serializer = serializer
def dumps(self, obj):
@@ -471,6 +467,9 @@ def dumps(self, obj):
def loads(self, obj):
return self.serializer.loads(zlib.decompress(obj))
+ def __eq__(self, other):
+ return isinstance(other, CompressedSerializer) and self.serializer == other.serializer
+
class UTF8Deserializer(Serializer):
@@ -497,6 +496,9 @@ def load_stream(self, stream):
except EOFError:
return
+ def __eq__(self, other):
+ return isinstance(other, UTF8Deserializer) and self.use_unicode == other.use_unicode
+
def read_long(stream):
length = stream.read(8)
@@ -527,3 +529,8 @@ def write_int(value, stream):
def write_with_length(obj, stream):
write_int(len(obj), stream)
stream.write(obj)
+
+
+if __name__ == '__main__':
+ import doctest
+ doctest.testmod()
diff --git a/python/pyspark/shuffle.py b/python/pyspark/shuffle.py
index d57a802e4734a..10a7ccd502000 100644
--- a/python/pyspark/shuffle.py
+++ b/python/pyspark/shuffle.py
@@ -25,7 +25,7 @@
import random
import pyspark.heapq3 as heapq
-from pyspark.serializers import BatchedSerializer, PickleSerializer
+from pyspark.serializers import AutoBatchedSerializer, PickleSerializer
try:
import psutil
@@ -213,8 +213,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
Merger.__init__(self, aggregator)
self.memory_limit = memory_limit
# default serializer is only used for tests
- self.serializer = serializer or \
- BatchedSerializer(PickleSerializer(), 1024)
+ self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
@@ -470,7 +469,7 @@ class ExternalSorter(object):
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
- self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
+ self.serializer = serializer or AutoBatchedSerializer(PickleSerializer())
def _get_path(self, n):
""" Choose one directory for spill by number n """
@@ -479,13 +478,21 @@ def _get_path(self, n):
os.makedirs(d)
return os.path.join(d, str(n))
+ def _next_limit(self):
+ """
+ Return the next memory limit. If the memory is not released
+ after spilling, it will dump the data only when the used memory
+ starts to increase.
+ """
+ return max(self.memory_limit, get_used_memory() * 1.05)
+
def sorted(self, iterator, key=None, reverse=False):
"""
Sort the elements in iterator, do external sort when the memory
goes above the limit.
"""
global MemoryBytesSpilled, DiskBytesSpilled
- batch = 100
+ batch, limit = 100, self._next_limit()
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
@@ -505,6 +512,7 @@ def sorted(self, iterator, key=None, reverse=False):
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []
gc.collect()
+ limit = self._next_limit()
MemoryBytesSpilled += (used_memory - get_used_memory()) << 20
DiskBytesSpilled += os.path.getsize(path)
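
`_next_limit` keeps raising the spill threshold to 5% above current usage, so a spill that the OS does not reclaim will not trigger an immediate re-spill. The arithmetic with illustrative numbers:

    memory_limit = 512                                # MB, configured limit
    used_after_spill = 600                            # MB, memory the OS kept allocated
    print max(memory_limit, used_after_spill * 1.05)  # 630.0 -> spill again only above this
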
diff --git a/python/pyspark/sql.py b/python/pyspark/sql.py
index 98e41f8575679..ae288471b0e51 100644
--- a/python/pyspark/sql.py
+++ b/python/pyspark/sql.py
@@ -44,7 +44,8 @@
from py4j.java_collections import ListConverter, MapConverter
from pyspark.rdd import RDD
-from pyspark.serializers import BatchedSerializer, PickleSerializer, CloudPickleSerializer
+from pyspark.serializers import BatchedSerializer, AutoBatchedSerializer, PickleSerializer, \
+ CloudPickleSerializer, UTF8Deserializer
from pyspark.storagelevel import StorageLevel
from pyspark.traceback_utils import SCCallSiteSync
@@ -109,6 +110,15 @@ def __eq__(self, other):
return self is other
+class NullType(PrimitiveType):
+
+ """Spark SQL NullType
+
+    The data type representing None, used for types that have not
+    been inferred.
+ """
+
+
class StringType(PrimitiveType):
"""Spark SQL StringType
@@ -331,7 +341,7 @@ class StructField(DataType):
"""
- def __init__(self, name, dataType, nullable, metadata=None):
+ def __init__(self, name, dataType, nullable=True, metadata=None):
"""Creates a StructField
:param name: the name of this field.
:param dataType: the data type of this field.
@@ -408,6 +418,75 @@ def fromJson(cls, json):
return StructType([StructField.fromJson(f) for f in json["fields"]])
+class UserDefinedType(DataType):
+ """
+ :: WARN: Spark Internal Use Only ::
+ SQL User-Defined Type (UDT).
+ """
+
+ @classmethod
+ def typeName(cls):
+ return cls.__name__.lower()
+
+ @classmethod
+ def sqlType(cls):
+ """
+ Underlying SQL storage type for this UDT.
+ """
+ raise NotImplementedError("UDT must implement sqlType().")
+
+ @classmethod
+ def module(cls):
+ """
+ The Python module of the UDT.
+ """
+ raise NotImplementedError("UDT must implement module().")
+
+ @classmethod
+ def scalaUDT(cls):
+ """
+ The class name of the paired Scala UDT.
+ """
+ raise NotImplementedError("UDT must have a paired Scala UDT.")
+
+ def serialize(self, obj):
+ """
+        Converts a user-type object into a SQL datum.
+ """
+ raise NotImplementedError("UDT must implement serialize().")
+
+ def deserialize(self, datum):
+ """
+ Converts a SQL datum into a user-type object.
+ """
+ raise NotImplementedError("UDT must implement deserialize().")
+
+ def json(self):
+ return json.dumps(self.jsonValue(), separators=(',', ':'), sort_keys=True)
+
+ def jsonValue(self):
+ schema = {
+ "type": "udt",
+ "class": self.scalaUDT(),
+ "pyClass": "%s.%s" % (self.module(), type(self).__name__),
+ "sqlType": self.sqlType().jsonValue()
+ }
+ return schema
+
+ @classmethod
+ def fromJson(cls, json):
+ pyUDT = json["pyClass"]
+ split = pyUDT.rfind(".")
+ pyModule = pyUDT[:split]
+ pyClass = pyUDT[split+1:]
+ m = __import__(pyModule, globals(), locals(), [pyClass], -1)
+ UDT = getattr(m, pyClass)
+ return UDT()
+
+ def __eq__(self, other):
+ return type(self) == type(other)
+
+
_all_primitive_types = dict((v.typeName(), v)
for v in globals().itervalues()
if type(v) is PrimitiveTypeSingleton and
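
The doctests below use `ExamplePoint` / `ExamplePointUDT`; a class that follows the contract above would look roughly like this sketch (the module path and the paired Scala class name are assumptions, and the real example classes live in Spark's test code):

    from pyspark.sql import UserDefinedType, ArrayType, DoubleType

    class ExamplePointUDT(UserDefinedType):
        @classmethod
        def sqlType(cls):
            return ArrayType(DoubleType(), False)     # stored as an array of two doubles

        @classmethod
        def module(cls):
            return "pyspark.tests"                    # assumed Python module

        @classmethod
        def scalaUDT(cls):
            return "org.apache.spark.sql.test.ExamplePointUDT"   # assumed Scala pair

        def serialize(self, obj):
            return [obj.x, obj.y]

        def deserialize(self, datum):
            return ExamplePoint(datum[0], datum[1])

    class ExamplePoint(object):
        __UDT__ = ExamplePointUDT()                   # links instances to their UDT

        def __init__(self, x, y):
            self.x, self.y = x, y
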
@@ -460,6 +539,12 @@ def _parse_datatype_json_string(json_string):
... complex_arraytype, False)
>>> check_datatype(complex_maptype)
True
+ >>> check_datatype(ExamplePointUDT())
+ True
+ >>> structtype_with_udt = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> check_datatype(structtype_with_udt)
+ True
"""
return _parse_datatype_json_value(json.loads(json_string))
@@ -479,11 +564,18 @@ def _parse_datatype_json_value(json_value):
else:
raise ValueError("Could not parse datatype: %s" % json_value)
else:
- return _all_complex_types[json_value["type"]].fromJson(json_value)
+ tpe = json_value["type"]
+ if tpe in _all_complex_types:
+ return _all_complex_types[tpe].fromJson(json_value)
+ elif tpe == 'udt':
+ return UserDefinedType.fromJson(json_value)
+ else:
+ raise ValueError("not supported type: %s" % tpe)
# Mapping Python types to Spark SQL DataType
_type_mappings = {
+ type(None): NullType,
bool: BooleanType,
int: IntegerType,
long: LongType,
@@ -499,23 +591,34 @@ def _parse_datatype_json_value(json_value):
def _infer_type(obj):
- """Infer the DataType from obj"""
+ """Infer the DataType from obj
+
+ >>> p = ExamplePoint(1.0, 2.0)
+ >>> _infer_type(p)
+ ExamplePointUDT
+ """
if obj is None:
raise ValueError("Can not infer type for None")
+ if hasattr(obj, '__UDT__'):
+ return obj.__UDT__
+
dataType = _type_mappings.get(type(obj))
if dataType is not None:
return dataType()
if isinstance(obj, dict):
- if not obj:
- raise ValueError("Can not infer type for empty dict")
- key, value = obj.iteritems().next()
- return MapType(_infer_type(key), _infer_type(value), True)
+ for key, value in obj.iteritems():
+ if key is not None and value is not None:
+ return MapType(_infer_type(key), _infer_type(value), True)
+ else:
+ return MapType(NullType(), NullType(), True)
elif isinstance(obj, (list, array)):
- if not obj:
- raise ValueError("Can not infer type for empty list/array")
- return ArrayType(_infer_type(obj[0]), True)
+ for v in obj:
+ if v is not None:
+ return ArrayType(_infer_type(obj[0]), True)
+ else:
+ return ArrayType(NullType(), True)
else:
try:
return _infer_schema(obj)
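
With these branches, empty or all-None containers no longer raise; they infer `NullType` placeholders that can be filled in later. A couple of illustrative calls (using this module's private helpers, doctest-style):

    m = _infer_type({})                                   # MapType(NullType, NullType, True)
    print type(m).__name__, type(m.valueType).__name__    # MapType NullType
    a = _infer_type([None, None])                         # ArrayType(NullType, True)
    print type(a.elementType).__name__                    # NullType
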
@@ -548,60 +651,180 @@ def _infer_schema(row):
return StructType(fields)
-def _create_converter(obj, dataType):
+def _need_python_to_sql_conversion(dataType):
+ """
+ Checks whether we need python to sql conversion for the given type.
+ For now, only UDTs need this conversion.
+
+ >>> _need_python_to_sql_conversion(DoubleType())
+ False
+ >>> schema0 = StructType([StructField("indices", ArrayType(IntegerType(), False), False),
+ ... StructField("values", ArrayType(DoubleType(), False), False)])
+ >>> _need_python_to_sql_conversion(schema0)
+ False
+ >>> _need_python_to_sql_conversion(ExamplePointUDT())
+ True
+ >>> schema1 = ArrayType(ExamplePointUDT(), False)
+ >>> _need_python_to_sql_conversion(schema1)
+ True
+ >>> schema2 = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> _need_python_to_sql_conversion(schema2)
+ True
+ """
+ if isinstance(dataType, StructType):
+ return any([_need_python_to_sql_conversion(f.dataType) for f in dataType.fields])
+ elif isinstance(dataType, ArrayType):
+ return _need_python_to_sql_conversion(dataType.elementType)
+ elif isinstance(dataType, MapType):
+ return _need_python_to_sql_conversion(dataType.keyType) or \
+ _need_python_to_sql_conversion(dataType.valueType)
+ elif isinstance(dataType, UserDefinedType):
+ return True
+ else:
+ return False
+
+
+def _python_to_sql_converter(dataType):
+ """
+ Returns a converter that converts a Python object into a SQL datum for the given type.
+
+ >>> conv = _python_to_sql_converter(DoubleType())
+ >>> conv(1.0)
+ 1.0
+ >>> conv = _python_to_sql_converter(ArrayType(DoubleType(), False))
+ >>> conv([1.0, 2.0])
+ [1.0, 2.0]
+ >>> conv = _python_to_sql_converter(ExamplePointUDT())
+ >>> conv(ExamplePoint(1.0, 2.0))
+ [1.0, 2.0]
+ >>> schema = StructType([StructField("label", DoubleType(), False),
+ ... StructField("point", ExamplePointUDT(), False)])
+ >>> conv = _python_to_sql_converter(schema)
+ >>> conv((1.0, ExamplePoint(1.0, 2.0)))
+ (1.0, [1.0, 2.0])
+ """
+ if not _need_python_to_sql_conversion(dataType):
+ return lambda x: x
+
+ if isinstance(dataType, StructType):
+ names, types = zip(*[(f.name, f.dataType) for f in dataType.fields])
+ converters = map(_python_to_sql_converter, types)
+
+ def converter(obj):
+ if isinstance(obj, dict):
+ return tuple(c(obj.get(n)) for n, c in zip(names, converters))
+ elif isinstance(obj, tuple):
+ if hasattr(obj, "_fields") or hasattr(obj, "__FIELDS__"):
+ return tuple(c(v) for c, v in zip(converters, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj): # k-v pairs
+ d = dict(obj)
+ return tuple(c(d.get(n)) for n, c in zip(names, converters))
+ else:
+ return tuple(c(v) for c, v in zip(converters, obj))
+ else:
+ raise ValueError("Unexpected tuple %r with type %r" % (obj, dataType))
+ return converter
+ elif isinstance(dataType, ArrayType):
+ element_converter = _python_to_sql_converter(dataType.elementType)
+ return lambda a: [element_converter(v) for v in a]
+ elif isinstance(dataType, MapType):
+ key_converter = _python_to_sql_converter(dataType.keyType)
+ value_converter = _python_to_sql_converter(dataType.valueType)
+ return lambda m: dict([(key_converter(k), value_converter(v)) for k, v in m.items()])
+ elif isinstance(dataType, UserDefinedType):
+ return lambda obj: dataType.serialize(obj)
+ else:
+ raise ValueError("Unexpected type %r" % dataType)
+
+
+def _has_nulltype(dt):
+ """ Return whether there is NullType in `dt` or not """
+ if isinstance(dt, StructType):
+ return any(_has_nulltype(f.dataType) for f in dt.fields)
+ elif isinstance(dt, ArrayType):
+ return _has_nulltype((dt.elementType))
+ elif isinstance(dt, MapType):
+ return _has_nulltype(dt.keyType) or _has_nulltype(dt.valueType)
+ else:
+ return isinstance(dt, NullType)
+
+
+def _merge_type(a, b):
+ if isinstance(a, NullType):
+ return b
+ elif isinstance(b, NullType):
+ return a
+ elif type(a) is not type(b):
+ # TODO: type cast (such as int -> long)
+ raise TypeError("Can not merge type %s and %s" % (a, b))
+
+ # same type
+ if isinstance(a, StructType):
+ nfs = dict((f.name, f.dataType) for f in b.fields)
+ fields = [StructField(f.name, _merge_type(f.dataType, nfs.get(f.name, NullType())))
+ for f in a.fields]
+ names = set([f.name for f in fields])
+ for n in nfs:
+ if n not in names:
+ fields.append(StructField(n, nfs[n]))
+ return StructType(fields)
+
+ elif isinstance(a, ArrayType):
+ return ArrayType(_merge_type(a.elementType, b.elementType), True)
+
+ elif isinstance(a, MapType):
+ return MapType(_merge_type(a.keyType, b.keyType),
+ _merge_type(a.valueType, b.valueType),
+ True)
+ else:
+ return a
+
+
+def _create_converter(dataType):
"""Create an converter to drop the names of fields in obj """
if isinstance(dataType, ArrayType):
- conv = _create_converter(obj[0], dataType.elementType)
+ conv = _create_converter(dataType.elementType)
return lambda row: map(conv, row)
elif isinstance(dataType, MapType):
- value = obj.values()[0]
- conv = _create_converter(value, dataType.valueType)
+ conv = _create_converter(dataType.valueType)
return lambda row: dict((k, conv(v)) for k, v in row.iteritems())
+ elif isinstance(dataType, NullType):
+ return lambda x: None
+
elif not isinstance(dataType, StructType):
return lambda x: x
# dataType must be StructType
names = [f.name for f in dataType.fields]
+ converters = [_create_converter(f.dataType) for f in dataType.fields]
+
+ def convert_struct(obj):
+ if obj is None:
+ return
+
+ if isinstance(obj, tuple):
+ if hasattr(obj, "fields"):
+ d = dict(zip(obj.fields, obj))
+ if hasattr(obj, "__FIELDS__"):
+ d = dict(zip(obj.__FIELDS__, obj))
+ elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
+ d = dict(obj)
+ else:
+ raise ValueError("unexpected tuple: %s" % obj)
- if isinstance(obj, dict):
- conv = lambda o: tuple(o.get(n) for n in names)
-
- elif isinstance(obj, tuple):
- if hasattr(obj, "_fields"): # namedtuple
- conv = tuple
- elif hasattr(obj, "__FIELDS__"):
- conv = tuple
- elif all(isinstance(x, tuple) and len(x) == 2 for x in obj):
- conv = lambda o: tuple(v for k, v in o)
+ elif isinstance(obj, dict):
+ d = obj
+ elif hasattr(obj, "__dict__"): # object
+ d = obj.__dict__
else:
- raise ValueError("unexpected tuple")
-
- elif hasattr(obj, "__dict__"): # object
- conv = lambda o: [o.__dict__.get(n, None) for n in names]
-
- if all(isinstance(f.dataType, PrimitiveType) for f in dataType.fields):
- return conv
-
- row = conv(obj)
- convs = [_create_converter(v, f.dataType)
- for v, f in zip(row, dataType.fields)]
-
- def nested_conv(row):
- return tuple(f(v) for f, v in zip(convs, conv(row)))
+ raise ValueError("Unexpected obj: %s" % obj)
- return nested_conv
+ return tuple([conv(d.get(name)) for name, conv in zip(names, converters)])
-
-def _drop_schema(rows, schema):
- """ all the names of fields, becoming tuples"""
- iterator = iter(rows)
- row = iterator.next()
- converter = _create_converter(row, schema)
- yield converter(row)
- for i in iterator:
- yield converter(i)
+ return convert_struct
_BRACKETS = {'(': ')', '[': ']', '{': '}'}
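
`_merge_type` is what lets `inferSchema` combine per-row schemas, so a field whose type cannot be inferred from the first row (a `NullType` placeholder) picks it up from a later row. A small sketch using this module's private helpers:

    row1 = Row(name=u"Alice", scores=[])              # element type unknown yet
    row2 = Row(name=u"Bob", scores=[80, 95])
    s1, s2 = _infer_schema(row1), _infer_schema(row2)
    print _has_nulltype(s1), _has_nulltype(s2)        # True False
    merged = _merge_type(s1, s2)
    print _has_nulltype(merged)                       # False
    scores = [f for f in merged.fields if f.name == "scores"][0]
    print type(scores.dataType.elementType).__name__  # IntegerType
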
@@ -713,7 +936,7 @@ def _infer_schema_type(obj, dataType):
return _infer_type(obj)
if not obj:
- raise ValueError("Can not infer type from empty value")
+ return NullType()
if isinstance(dataType, ArrayType):
eType = _infer_schema_type(obj[0], dataType.elementType)
@@ -775,11 +998,22 @@ def _verify_type(obj, dataType):
Traceback (most recent call last):
...
ValueError:...
+ >>> _verify_type(ExamplePoint(1.0, 2.0), ExamplePointUDT())
+ >>> _verify_type([1.0, 2.0], ExamplePointUDT()) # doctest: +IGNORE_EXCEPTION_DETAIL
+ Traceback (most recent call last):
+ ...
+ ValueError:...
"""
# all objects are nullable
if obj is None:
return
+ if isinstance(dataType, UserDefinedType):
+ if not (hasattr(obj, '__UDT__') and obj.__UDT__ == dataType):
+ raise ValueError("%r is not an instance of type %r" % (obj, dataType))
+ _verify_type(dataType.serialize(obj), dataType.sqlType())
+ return
+
_type = type(dataType)
    assert _type in _acceptable_types, "unknown datatype: %s" % dataType
@@ -854,6 +1088,8 @@ def _has_struct_or_date(dt):
return _has_struct_or_date(dt.valueType)
elif isinstance(dt, DateType):
return True
+ elif isinstance(dt, UserDefinedType):
+ return True
return False
@@ -924,6 +1160,9 @@ def Dict(d):
elif isinstance(dataType, DateType):
return datetime.date
+ elif isinstance(dataType, UserDefinedType):
+ return lambda datum: dataType.deserialize(datum)
+
elif not isinstance(dataType, StructType):
raise Exception("unexpected data type: %s" % dataType)
@@ -939,7 +1178,7 @@ class Row(tuple):
def asDict(self):
""" Return as a dict """
- return dict(zip(self.__FIELDS__, self))
+ return dict((n, getattr(self, n)) for n in self.__FIELDS__)
def __repr__(self):
# call collect __repr__ for nested objects
@@ -995,7 +1234,6 @@ def __init__(self, sparkContext, sqlContext=None):
self._sc = sparkContext
self._jsc = self._sc._jsc
self._jvm = self._sc._jvm
- self._pythonToJava = self._jvm.PythonRDD.pythonToJavaArray
self._scala_SQLContext = sqlContext
@property
@@ -1025,8 +1263,8 @@ def registerFunction(self, name, f, returnType=StringType()):
"""
func = lambda _, it: imap(lambda x: f(*x), it)
command = (func, None,
- BatchedSerializer(PickleSerializer(), 1024),
- BatchedSerializer(PickleSerializer(), 1024))
+ AutoBatchedSerializer(PickleSerializer()),
+ AutoBatchedSerializer(PickleSerializer()))
ser = CloudPickleSerializer()
pickled_command = ser.dumps(command)
if len(pickled_command) > (1 << 20): # 1M
@@ -1049,18 +1287,20 @@ def registerFunction(self, name, f, returnType=StringType()):
self._sc._javaAccumulator,
returnType.json())
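As context for the serializer change above, a usage sketch (not part of the patch) of registerFunction, assuming a SQLContext bound to `sqlCtx` as in the doctests; the function name `strlen` is illustrative:

    sqlCtx.registerFunction("strlen", lambda s: len(s), IntegerType())
    # the pickled command is now shipped with AutoBatchedSerializer (see the change above)
    sqlCtx.sql("SELECT strlen('test')").collect()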
- def inferSchema(self, rdd):
+ def inferSchema(self, rdd, samplingRatio=None):
"""Infer and apply a schema to an RDD of L{Row}.
- We peek at the first row of the RDD to determine the fields' names
- and types. Nested collections are supported, which include array,
- dict, list, Row, tuple, namedtuple, or object.
+ When samplingRatio is specified, the schema is inferred by looking
+ at the types of each row in the sampled dataset. Otherwise, the
+ first 100 rows of the RDD are inspected. Nested collections are
+ supported, which can include array, dict, list, Row, tuple,
+ namedtuple, or object.
- All the rows in `rdd` should have the same type with the first one,
- or it will cause runtime exceptions.
+        Each row can be an L{pyspark.sql.Row} object, a namedtuple, or a plain Python object.
+        Using top-level dicts is deprecated, as dict is used to represent Maps.
- Each row could be L{pyspark.sql.Row} object or namedtuple or objects,
- using dict is deprecated.
+ If a single column has multiple distinct inferred types, it may cause
+ runtime exceptions.
>>> rdd = sc.parallelize(
... [Row(field1=1, field2="row1"),
@@ -1097,8 +1337,23 @@ def inferSchema(self, rdd):
warnings.warn("Using RDD of dict to inferSchema is deprecated,"
"please use pyspark.sql.Row instead")
- schema = _infer_schema(first)
- rdd = rdd.mapPartitions(lambda rows: _drop_schema(rows, schema))
+ if samplingRatio is None:
+ schema = _infer_schema(first)
+ if _has_nulltype(schema):
+ for row in rdd.take(100)[1:]:
+ schema = _merge_type(schema, _infer_schema(row))
+ if not _has_nulltype(schema):
+ break
+ else:
+ warnings.warn("Some of types cannot be determined by the "
+ "first 100 rows, please try again with sampling")
+ else:
+            if samplingRatio < 0.99:
+ rdd = rdd.sample(False, float(samplingRatio))
+ schema = rdd.map(_infer_schema).reduce(_merge_type)
+
+ converter = _create_converter(schema)
+ rdd = rdd.map(converter)
return self.applySchema(rdd, schema)
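A brief usage sketch (not part of the patch) of the new samplingRatio argument, assuming `sc` and `sqlCtx` as in the doctests:

    rows = sc.parallelize([Row(a=1, b=None)] + [Row(a=i, b=str(i)) for i in range(2, 200)])
    srdd = sqlCtx.inferSchema(rows)        # default: inspects up to the first 100 rows
    srdd2 = sqlCtx.inferSchema(rows, 0.5)  # samplingRatio: infers from a sample of the whole RDD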
def applySchema(self, rdd, schema):
@@ -1184,8 +1439,11 @@ def applySchema(self, rdd, schema):
for row in rows:
_verify_type(row, schema)
- batched = isinstance(rdd._jrdd_deserializer, BatchedSerializer)
- jrdd = self._pythonToJava(rdd._jrdd, batched)
+ # convert python objects to sql data
+ converter = _python_to_sql_converter(schema)
+ rdd = rdd.map(converter)
+
+ jrdd = self._jvm.SerDeUtil.toJavaArray(rdd._to_java_object_rdd())
srdd = self._ssql_ctx.applySchemaToPythonRDD(jrdd.rdd(), schema.json())
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
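For contrast with inferSchema, a short sketch (not part of the patch) of applySchema with an explicit schema; the converter added above turns each Python row into SQL data before it is handed to the JVM:

    schema = StructType([StructField("id", IntegerType(), False),
                         StructField("name", StringType(), True)])
    srdd = sqlCtx.applySchema(sc.parallelize([(1, "a"), (2, "b")]), schema)
    srdd.registerTempTable("pairs")
    sqlCtx.sql("SELECT name FROM pairs WHERE id = 2").collect()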
@@ -1219,7 +1477,7 @@ def parquetFile(self, path):
jschema_rdd = self._ssql_ctx.parquetFile(path).toJavaSchemaRDD()
return SchemaRDD(jschema_rdd, self)
- def jsonFile(self, path, schema=None):
+ def jsonFile(self, path, schema=None, samplingRatio=1.0):
"""
Loads a text file storing one JSON object per line as a
L{SchemaRDD}.
@@ -1227,8 +1485,8 @@ def jsonFile(self, path, schema=None):
If the schema is provided, applies the given schema to this
JSON dataset.
- Otherwise, it goes through the entire dataset once to determine
- the schema.
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
>>> import tempfile, shutil
>>> jsonFile = tempfile.mkdtemp()
@@ -1274,20 +1532,20 @@ def jsonFile(self, path, schema=None):
[Row(f1=u'row1', f2=None, f3=None)...Row(f1=u'row3', f2=[], f3=None)]
"""
if schema is None:
- srdd = self._ssql_ctx.jsonFile(path)
+ srdd = self._ssql_ctx.jsonFile(path, samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonFile(path, scala_datatype)
return SchemaRDD(srdd.toJavaSchemaRDD(), self)
- def jsonRDD(self, rdd, schema=None):
+ def jsonRDD(self, rdd, schema=None, samplingRatio=1.0):
"""Loads an RDD storing one JSON object per string as a L{SchemaRDD}.
If the schema is provided, applies the given schema to this
JSON dataset.
- Otherwise, it goes through the entire dataset once to determine
- the schema.
+ Otherwise, it samples the dataset with ratio `samplingRatio` to
+ determine the schema.
>>> srdd1 = sqlCtx.jsonRDD(json)
>>> sqlCtx.registerRDDAsTable(srdd1, "table1")
@@ -1344,7 +1602,7 @@ def func(iterator):
keyed._bypass_serializer = True
jrdd = keyed._jrdd.map(self._jvm.BytesToString())
if schema is None:
- srdd = self._ssql_ctx.jsonRDD(jrdd.rdd())
+ srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), samplingRatio)
else:
scala_datatype = self._ssql_ctx.parseDataType(schema.json())
srdd = self._ssql_ctx.jsonRDD(jrdd.rdd(), scala_datatype)
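A usage sketch (not part of the patch) of the new samplingRatio parameter on the JSON loaders, assuming `sqlCtx` and the `json` RDD of JSON strings used in the doctests; the file path is a placeholder:

    srdd_full = sqlCtx.jsonRDD(json)                        # default ratio 1.0: full pass over the data
    srdd_sampled = sqlCtx.jsonRDD(json, samplingRatio=0.1)  # schema inferred from ~10% of the rows
    srdd_file = sqlCtx.jsonFile("path/to/json/dir", samplingRatio=0.5)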
@@ -1582,7 +1840,7 @@ def __init__(self, jschema_rdd, sql_ctx):
self.is_checkpointed = False
self.ctx = self.sql_ctx._sc
# the _jrdd is created by javaToPython(), serialized by pickle
- self._jrdd_deserializer = BatchedSerializer(PickleSerializer())
+ self._jrdd_deserializer = AutoBatchedSerializer(PickleSerializer())
@property
def _jrdd(self):
@@ -1612,6 +1870,21 @@ def limit(self, num):
rdd = self._jschema_rdd.baseSchemaRDD().limit(num).toJavaSchemaRDD()
return SchemaRDD(rdd, self.sql_ctx)
+ def toJSON(self, use_unicode=False):
+ """Convert a SchemaRDD into a MappedRDD of JSON documents; one document per row.
+
+ >>> srdd1 = sqlCtx.jsonRDD(json)
+ >>> sqlCtx.registerRDDAsTable(srdd1, "table1")
+ >>> srdd2 = sqlCtx.sql( "SELECT * from table1")
+ >>> srdd2.toJSON().take(1)[0] == '{"field1":1,"field2":"row1","field3":{"field4":11}}'
+ True
+ >>> srdd3 = sqlCtx.sql( "SELECT field3.field4 from table1")
+ >>> srdd3.toJSON().collect() == ['{"field4":11}', '{"field4":22}', '{"field4":33}']
+ True
+ """
+ rdd = self._jschema_rdd.baseSchemaRDD().toJSON()
+ return RDD(rdd.toJavaRDD(), self._sc, UTF8Deserializer(use_unicode))
+
def saveAsParquetFile(self, path):
"""Save the contents as a Parquet file, preserving the schema.
@@ -1812,15 +2085,13 @@ def subtract(self, other, numPartitions=None):
def _test():
import doctest
- from array import array
from pyspark.context import SparkContext
# let doctest run in pyspark.sql, so DataTypes can be picklable
import pyspark.sql
from pyspark.sql import Row, SQLContext
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
globs = pyspark.sql.__dict__.copy()
- # The small batch size here ensures that we see multiple batches,
- # even in these small test examples:
- sc = SparkContext('local[4]', 'PythonTest', batchSize=2)
+ sc = SparkContext('local[4]', 'PythonTest')
globs['sc'] = sc
globs['sqlCtx'] = SQLContext(sc)
globs['rdd'] = sc.parallelize(
@@ -1828,6 +2099,8 @@ def _test():
Row(field1=2, field2="row2"),
Row(field1=3, field2="row3")]
)
+ globs['ExamplePoint'] = ExamplePoint
+ globs['ExamplePointUDT'] = ExamplePointUDT
jsonStrings = [
'{"field1": 1, "field2": "row1", "field3":{"field4":11}}',
'{"field1" : 2, "field3":{"field4":22, "field5": [10, 11]},'
diff --git a/python/pyspark/streaming/context.py b/python/pyspark/streaming/context.py
index 2f53fbd27b17a..d48f3598e33b2 100644
--- a/python/pyspark/streaming/context.py
+++ b/python/pyspark/streaming/context.py
@@ -142,8 +142,8 @@ def getOrCreate(cls, checkpointPath, setupFunc):
recreated from the checkpoint data. If the data does not exist, then the provided setupFunc
will be used to create a JavaStreamingContext.
- @param checkpointPath Checkpoint directory used in an earlier JavaStreamingContext program
- @param setupFunc Function to create a new JavaStreamingContext and setup DStreams
+ @param checkpointPath: Checkpoint directory used in an earlier JavaStreamingContext program
+ @param setupFunc: Function to create a new JavaStreamingContext and setup DStreams
"""
# TODO: support checkpoint in HDFS
if not os.path.exists(checkpointPath) or not os.listdir(checkpointPath):
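A minimal usage sketch (not part of the patch) of getOrCreate; the checkpoint directory and socket source below are placeholders, and `sc` is an existing SparkContext:

    from pyspark.streaming import StreamingContext

    def create_context():
        ssc = StreamingContext(sc, 1)                    # 1-second batches
        ssc.socketTextStream("localhost", 9999).count().pprint()
        ssc.checkpoint("/tmp/streaming-checkpoint")      # placeholder local directory
        return ssc

    ssc = StreamingContext.getOrCreate("/tmp/streaming-checkpoint", create_context)
    ssc.start()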
diff --git a/python/pyspark/tests.py b/python/pyspark/tests.py
index 37a128907b3a7..bca52a7ce6d58 100644
--- a/python/pyspark/tests.py
+++ b/python/pyspark/tests.py
@@ -32,6 +32,7 @@
import zipfile
import random
import threading
+import hashlib
if sys.version_info[:2] <= (2, 6):
try:
@@ -47,9 +48,10 @@
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int, BatchedSerializer, MarshalSerializer, PickleSerializer, \
- CloudPickleSerializer
+ CloudPickleSerializer, CompressedSerializer
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
-from pyspark.sql import SQLContext, IntegerType, Row, ArrayType
+from pyspark.sql import SQLContext, IntegerType, Row, ArrayType, StructType, StructField, \
+ UserDefinedType, DoubleType
from pyspark import shuffle
_have_scipy = False
@@ -235,13 +237,24 @@ def foo():
self.assertTrue("exit" in foo.func_code.co_names)
ser.dumps(foo)
+ def test_compressed_serializer(self):
+ ser = CompressedSerializer(PickleSerializer())
+ from StringIO import StringIO
+ io = StringIO()
+ ser.dump_stream(["abc", u"123", range(5)], io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)], list(ser.load_stream(io)))
+ ser.dump_stream(range(1000), io)
+ io.seek(0)
+ self.assertEqual(["abc", u"123", range(5)] + range(1000), list(ser.load_stream(io)))
+
class PySparkTestCase(unittest.TestCase):
def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
- self.sc = SparkContext('local[4]', class_name, batchSize=2)
+ self.sc = SparkContext('local[4]', class_name)
def tearDown(self):
self.sc.stop()
@@ -252,7 +265,7 @@ class ReusedPySparkTestCase(unittest.TestCase):
@classmethod
def setUpClass(cls):
- cls.sc = SparkContext('local[4]', cls.__name__, batchSize=2)
+ cls.sc = SparkContext('local[4]', cls.__name__)
@classmethod
def tearDownClass(cls):
@@ -439,7 +452,7 @@ def test_sampling_default_seed(self):
subset = data.takeSample(False, 10)
self.assertEqual(len(subset), 10)
- def testAggregateByKey(self):
+ def test_aggregate_by_key(self):
data = self.sc.parallelize([(1, 1), (1, 1), (3, 2), (5, 1), (5, 3)], 2)
def seqOp(x, y):
@@ -477,6 +490,32 @@ def test_large_broadcast(self):
m = self.sc.parallelize(range(1), 1).map(lambda x: len(bdata.value)).sum()
self.assertEquals(N, m)
+ def test_multiple_broadcasts(self):
+ N = 1 << 21
+ b1 = self.sc.broadcast(set(range(N))) # multiple blocks in JVM
+ r = range(1 << 15)
+ random.shuffle(r)
+ s = str(r)
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
+ random.shuffle(r)
+ s = str(r)
+ checksum = hashlib.md5(s).hexdigest()
+ b2 = self.sc.broadcast(s)
+ r = list(set(self.sc.parallelize(range(10), 10).map(
+ lambda x: (len(b1.value), hashlib.md5(b2.value).hexdigest())).collect()))
+ self.assertEqual(1, len(r))
+ size, csum = r[0]
+ self.assertEqual(N, size)
+ self.assertEqual(checksum, csum)
+
def test_large_closure(self):
N = 1000000
data = [float(i) for i in xrange(N)]
@@ -494,6 +533,15 @@ def test_zip_with_different_serializers(self):
a = a._reserialize(BatchedSerializer(PickleSerializer(), 2))
b = b._reserialize(MarshalSerializer())
self.assertEqual(a.zip(b).collect(), [(0, 100), (1, 101), (2, 102), (3, 103), (4, 104)])
+ # regression test for SPARK-4841
+ path = os.path.join(SPARK_HOME, "python/test_support/hello.txt")
+ t = self.sc.textFile(path)
+ cnt = t.count()
+ self.assertEqual(cnt, t.zip(t).count())
+ rdd = t.map(str)
+ self.assertEqual(cnt, t.zip(rdd).count())
+ # regression test for bug in _reserializer()
+ self.assertEqual(cnt, t.zip(rdd).count())
def test_zip_with_different_number_of_items(self):
a = self.sc.parallelize(range(5), 2)
@@ -648,6 +696,24 @@ def test_distinct(self):
self.assertEquals(result.getNumPartitions(), 5)
self.assertEquals(result.count(), 3)
+ def test_sort_on_empty_rdd(self):
+ self.assertEqual([], self.sc.parallelize(zip([], [])).sortByKey().collect())
+
+ def test_sample(self):
+ rdd = self.sc.parallelize(range(0, 100), 4)
+ wo = rdd.sample(False, 0.1, 2).collect()
+ wo_dup = rdd.sample(False, 0.1, 2).collect()
+ self.assertSetEqual(set(wo), set(wo_dup))
+ wr = rdd.sample(True, 0.2, 5).collect()
+ wr_dup = rdd.sample(True, 0.2, 5).collect()
+ self.assertSetEqual(set(wr), set(wr_dup))
+ wo_s10 = rdd.sample(False, 0.3, 10).collect()
+ wo_s20 = rdd.sample(False, 0.3, 20).collect()
+ self.assertNotEqual(set(wo_s10), set(wo_s20))
+ wr_s11 = rdd.sample(True, 0.4, 11).collect()
+ wr_s21 = rdd.sample(True, 0.4, 21).collect()
+ self.assertNotEqual(set(wr_s11), set(wr_s21))
+
class ProfilerTests(PySparkTestCase):
@@ -655,7 +721,7 @@ def setUp(self):
self._old_sys_path = list(sys.path)
class_name = self.__class__.__name__
conf = SparkConf().set("spark.python.profile", "true")
- self.sc = SparkContext('local[4]', class_name, batchSize=2, conf=conf)
+ self.sc = SparkContext('local[4]', class_name, conf=conf)
def test_profiler(self):
@@ -679,8 +745,65 @@ def heavy_foo(x):
self.assertTrue("rdd_%d.pstats" % id in os.listdir(d))
+class ExamplePointUDT(UserDefinedType):
+ """
+ User-defined type (UDT) for ExamplePoint.
+ """
+
+ @classmethod
+ def sqlType(self):
+ return ArrayType(DoubleType(), False)
+
+ @classmethod
+ def module(cls):
+ return 'pyspark.tests'
+
+ @classmethod
+ def scalaUDT(cls):
+ return 'org.apache.spark.sql.test.ExamplePointUDT'
+
+ def serialize(self, obj):
+ return [obj.x, obj.y]
+
+ def deserialize(self, datum):
+ return ExamplePoint(datum[0], datum[1])
+
+
+class ExamplePoint:
+ """
+ An example class to demonstrate UDT in Scala, Java, and Python.
+ """
+
+ __UDT__ = ExamplePointUDT()
+
+ def __init__(self, x, y):
+ self.x = x
+ self.y = y
+
+ def __repr__(self):
+ return "ExamplePoint(%s,%s)" % (self.x, self.y)
+
+ def __str__(self):
+ return "(%s,%s)" % (self.x, self.y)
+
+ def __eq__(self, other):
+ return isinstance(other, ExamplePoint) and \
+ other.x == self.x and other.y == self.y
+
+
class SQLTests(ReusedPySparkTestCase):
+ @classmethod
+ def setUpClass(cls):
+ ReusedPySparkTestCase.setUpClass()
+ cls.tempdir = tempfile.NamedTemporaryFile(delete=False)
+ os.unlink(cls.tempdir.name)
+
+ @classmethod
+ def tearDownClass(cls):
+ ReusedPySparkTestCase.tearDownClass()
+ shutil.rmtree(cls.tempdir.name, ignore_errors=True)
+
def setUp(self):
self.sqlCtx = SQLContext(self.sc)
@@ -781,14 +904,67 @@ def test_serialize_nested_array_and_map(self):
self.assertEqual(1.0, row.c)
self.assertEqual("2", row.d)
+ def test_infer_schema(self):
+ d = [Row(l=[], d={}),
+ Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")}, s="")]
+ rdd = self.sc.parallelize(d)
+ srdd = self.sqlCtx.inferSchema(rdd)
+ self.assertEqual([], srdd.map(lambda r: r.l).first())
+ self.assertEqual([None, ""], srdd.map(lambda r: r.s).collect())
+ srdd.registerTempTable("test")
+ result = self.sqlCtx.sql("SELECT l[0].a from test where d['key'].d = '2'")
+ self.assertEqual(1, result.first()[0])
+
+ srdd2 = self.sqlCtx.inferSchema(rdd, 1.0)
+ self.assertEqual(srdd.schema(), srdd2.schema())
+ self.assertEqual({}, srdd2.map(lambda r: r.d).first())
+ self.assertEqual([None, ""], srdd2.map(lambda r: r.s).collect())
+ srdd2.registerTempTable("test2")
+ result = self.sqlCtx.sql("SELECT l[0].a from test2 where d['key'].d = '2'")
+ self.assertEqual(1, result.first()[0])
+
def test_convert_row_to_dict(self):
row = Row(l=[Row(a=1, b='s')], d={"key": Row(c=1.0, d="2")})
self.assertEqual(1, row.asDict()['l'][0].a)
rdd = self.sc.parallelize([row])
srdd = self.sqlCtx.inferSchema(rdd)
srdd.registerTempTable("test")
- row = self.sqlCtx.sql("select l[0].a AS la from test").first()
- self.assertEqual(1, row.asDict()["la"])
+ row = self.sqlCtx.sql("select l, d from test").first()
+ self.assertEqual(1, row.asDict()["l"][0].a)
+ self.assertEqual(1.0, row.asDict()['d']['key'].c)
+
+ def test_infer_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd = self.sqlCtx.inferSchema(rdd)
+ schema = srdd.schema()
+ field = [f for f in schema.fields if f.name == "point"][0]
+ self.assertEqual(type(field.dataType), ExamplePointUDT)
+ srdd.registerTempTable("labeled_point")
+ point = self.sqlCtx.sql("SELECT point FROM labeled_point").first().point
+ self.assertEqual(point, ExamplePoint(1.0, 2.0))
+
+ def test_apply_schema_with_udt(self):
+ from pyspark.tests import ExamplePoint, ExamplePointUDT
+ row = (1.0, ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ schema = StructType([StructField("label", DoubleType(), False),
+ StructField("point", ExamplePointUDT(), False)])
+ srdd = self.sqlCtx.applySchema(rdd, schema)
+ point = srdd.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
+
+ def test_parquet_with_udt(self):
+ from pyspark.tests import ExamplePoint
+ row = Row(label=1.0, point=ExamplePoint(1.0, 2.0))
+ rdd = self.sc.parallelize([row])
+ srdd0 = self.sqlCtx.inferSchema(rdd)
+ output_dir = os.path.join(self.tempdir.name, "labeled_point")
+ srdd0.saveAsParquetFile(output_dir)
+ srdd1 = self.sqlCtx.parquetFile(output_dir)
+ point = srdd1.first().point
+ self.assertEquals(point, ExamplePoint(1.0, 2.0))
class InputFormatTests(ReusedPySparkTestCase):
@@ -887,16 +1063,19 @@ def test_sequencefiles(self):
clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
"org.apache.hadoop.io.Text",
"org.apache.spark.api.python.TestWritable").collect())
- ec = (u'1',
- {u'__class__': u'org.apache.spark.api.python.TestWritable',
- u'double': 54.0, u'int': 123, u'str': u'test1'})
- self.assertEqual(clazz[0], ec)
+ cname = u'org.apache.spark.api.python.TestWritable'
+ ec = [(u'1', {u'__class__': cname, u'double': 1.0, u'int': 1, u'str': u'test1'}),
+ (u'2', {u'__class__': cname, u'double': 2.3, u'int': 2, u'str': u'test2'}),
+ (u'3', {u'__class__': cname, u'double': 3.1, u'int': 3, u'str': u'test3'}),
+ (u'4', {u'__class__': cname, u'double': 4.2, u'int': 4, u'str': u'test4'}),
+ (u'5', {u'__class__': cname, u'double': 5.5, u'int': 5, u'str': u'test56'})]
+ self.assertEqual(clazz, ec)
unbatched_clazz = sorted(self.sc.sequenceFile(basepath + "/sftestdata/sfclass/",
"org.apache.hadoop.io.Text",
"org.apache.spark.api.python.TestWritable",
- batchSize=1).collect())
- self.assertEqual(unbatched_clazz[0], ec)
+ ).collect())
+ self.assertEqual(unbatched_clazz, ec)
def test_oldhadoop(self):
basepath = self.tempdir.name
@@ -982,6 +1161,25 @@ def test_converters(self):
(u'\x03', [2.0])]
self.assertEqual(maps, em)
+ def test_binary_files(self):
+ path = os.path.join(self.tempdir.name, "binaryfiles")
+ os.mkdir(path)
+ data = "short binary data"
+ with open(os.path.join(path, "part-0000"), 'w') as f:
+ f.write(data)
+ [(p, d)] = self.sc.binaryFiles(path).collect()
+ self.assertTrue(p.endswith("part-0000"))
+ self.assertEqual(d, data)
+
+ def test_binary_records(self):
+ path = os.path.join(self.tempdir.name, "binaryrecords")
+ os.mkdir(path)
+ with open(os.path.join(path, "part-0000"), 'w') as f:
+ for i in range(100):
+ f.write('%04d' % i)
+ result = self.sc.binaryRecords(path, 4).map(int).collect()
+ self.assertEqual(range(100), result)
+
class OutputFormatTests(ReusedPySparkTestCase):
@@ -1216,51 +1414,6 @@ def test_reserialization(self):
result5 = sorted(self.sc.sequenceFile(basepath + "/reserialize/newdataset").collect())
self.assertEqual(result5, data)
- def test_unbatched_save_and_read(self):
- basepath = self.tempdir.name
- ei = [(1, u'aa'), (1, u'aa'), (2, u'aa'), (2, u'bb'), (2, u'bb'), (3, u'cc')]
- self.sc.parallelize(ei, len(ei)).saveAsSequenceFile(
- basepath + "/unbatched/")
-
- unbatched_sequence = sorted(self.sc.sequenceFile(
- basepath + "/unbatched/",
- batchSize=1).collect())
- self.assertEqual(unbatched_sequence, ei)
-
- unbatched_hadoopFile = sorted(self.sc.hadoopFile(
- basepath + "/unbatched/",
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- batchSize=1).collect())
- self.assertEqual(unbatched_hadoopFile, ei)
-
- unbatched_newAPIHadoopFile = sorted(self.sc.newAPIHadoopFile(
- basepath + "/unbatched/",
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- batchSize=1).collect())
- self.assertEqual(unbatched_newAPIHadoopFile, ei)
-
- oldconf = {"mapred.input.dir": basepath + "/unbatched/"}
- unbatched_hadoopRDD = sorted(self.sc.hadoopRDD(
- "org.apache.hadoop.mapred.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- conf=oldconf,
- batchSize=1).collect())
- self.assertEqual(unbatched_hadoopRDD, ei)
-
- newconf = {"mapred.input.dir": basepath + "/unbatched/"}
- unbatched_newAPIHadoopRDD = sorted(self.sc.newAPIHadoopRDD(
- "org.apache.hadoop.mapreduce.lib.input.SequenceFileInputFormat",
- "org.apache.hadoop.io.IntWritable",
- "org.apache.hadoop.io.Text",
- conf=newconf,
- batchSize=1).collect())
- self.assertEqual(unbatched_newAPIHadoopRDD, ei)
-
def test_malformed_RDD(self):
basepath = self.tempdir.name
# non-batch-serialized RDD[[(K, V)]] should be rejected
diff --git a/python/pyspark/worker.py b/python/pyspark/worker.py
index 2bdccb5e93f09..7e5343c973dc5 100644
--- a/python/pyspark/worker.py
+++ b/python/pyspark/worker.py
@@ -30,8 +30,7 @@
from pyspark.broadcast import Broadcast, _broadcastRegistry
from pyspark.files import SparkFiles
from pyspark.serializers import write_with_length, write_int, read_long, \
- write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer, \
- CompressedSerializer
+ write_long, read_int, SpecialLengths, UTF8Deserializer, PickleSerializer
from pyspark import shuffle
pickleSer = PickleSerializer()
@@ -78,12 +77,11 @@ def main(infile, outfile):
# fetch names and values of broadcast variables
num_broadcast_variables = read_int(infile)
- ser = CompressedSerializer(pickleSer)
for _ in range(num_broadcast_variables):
bid = read_long(infile)
if bid >= 0:
- value = ser._read_with_length(infile)
- _broadcastRegistry[bid] = Broadcast(bid, value)
+ path = utf8_deserializer.loads(infile)
+ _broadcastRegistry[bid] = Broadcast(path=path)
else:
bid = - bid - 1
_broadcastRegistry.pop(bid)
diff --git a/python/run-tests b/python/run-tests
index a4f0cac059ff3..9ee19ed6e6b26 100755
--- a/python/run-tests
+++ b/python/run-tests
@@ -56,7 +56,7 @@ function run_core_tests() {
run_test "pyspark/conf.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/broadcast.py"
PYSPARK_DOC_TEST=1 run_test "pyspark/accumulators.py"
- PYSPARK_DOC_TEST=1 run_test "pyspark/serializers.py"
+ run_test "pyspark/serializers.py"
run_test "pyspark/shuffle.py"
run_test "pyspark/tests.py"
}
@@ -72,7 +72,7 @@ function run_mllib_tests() {
run_test "pyspark/mllib/clustering.py"
run_test "pyspark/mllib/feature.py"
run_test "pyspark/mllib/linalg.py"
- run_test "pyspark/mllib/random.py"
+ run_test "pyspark/mllib/rand.py"
run_test "pyspark/mllib/recommendation.py"
run_test "pyspark/mllib/regression.py"
run_test "pyspark/mllib/stat.py"
diff --git a/repl/pom.xml b/repl/pom.xml
index af528c8914335..16408b2c71b9f 100644
--- a/repl/pom.xml
+++ b/repl/pom.xml
@@ -21,7 +21,7 @@
    <groupId>org.apache.spark</groupId>
    <artifactId>spark-parent</artifactId>
-   <version>1.2.0-SNAPSHOT</version>
+   <version>1.2.1-palantir2</version>
    <relativePath>../pom.xml</relativePath>
@@ -35,9 +35,16 @@
repl/usr/share/sparkroot
+   <extra.source.dir>scala-2.10/src/main/scala</extra.source.dir>
+   <extra.testsource.dir>scala-2.10/src/test/scala</extra.testsource.dir>
+
+   <dependency>
+     <groupId>${jline.groupid}</groupId>
+     <artifactId>jline</artifactId>
+     <version>${jline.version}</version>
+   </dependency>
+ org.apache.sparkspark-core_${scala.binary.version}
@@ -75,11 +82,6 @@
scala-reflect${scala.version}
-
- org.scala-lang
- jline
- ${scala.version}
- org.slf4jjul-to-slf4j
@@ -99,20 +101,6 @@
target/scala-${scala.binary.version}/classestarget/scala-${scala.binary.version}/test-classes
-
- org.apache.maven.plugins
- maven-deploy-plugin
-
- true
-
-
-
- org.apache.maven.plugins
- maven-install-plugin
-
- true
-
- org.scalatestscalatest-maven-plugin
@@ -122,6 +110,51 @@
+
+
+ org.codehaus.mojo
+ build-helper-maven-plugin
+
+
+ add-scala-sources
+ generate-sources
+
+ add-source
+
+
+
+ src/main/scala
+ ${extra.source.dir}
+
+
+
+
+ add-scala-test-sources
+ generate-test-sources
+
+ add-test-source
+
+
+
+ src/test/scala
+ ${extra.testsource.dir}
+
+
+
+
+
+
+
+ scala-2.11
+
+ scala-2.11
+
+
+ scala-2.11/src/main/scala
+ scala-2.11/src/test/scala
+
+
+
diff --git a/repl/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/Main.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/Main.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkCommandLine.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkHelper.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkHelper.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkILoop.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoop.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
similarity index 95%
rename from repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
index 7667a9c11979e..da4286c5e4874 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkILoopInit.scala
@@ -121,11 +121,14 @@ trait SparkILoopInit {
def initializeSpark() {
intp.beQuietDuring {
command("""
- @transient val sc = org.apache.spark.repl.Main.interp.createSparkContext();
+ @transient val sc = {
+ val _sc = org.apache.spark.repl.Main.interp.createSparkContext()
+ println("Spark context available as sc.")
+ _sc
+ }
""")
command("import org.apache.spark.SparkContext._")
}
- echo("Spark context available as sc.")
}
// code to be executed only after the interpreter is initialized
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
similarity index 99%
rename from repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
index 646c68e60c2e9..b646f0b6f0868 100644
--- a/repl/src/main/scala/org/apache/spark/repl/SparkIMain.scala
+++ b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -106,7 +106,7 @@ import org.apache.spark.util.Utils
val virtualDirectory = new PlainFile(outputDir) // "directory" for classfiles
/** Jetty server that will serve our classes to worker nodes */
val classServerPort = conf.getInt("spark.replClassServer.port", 0)
- val classServer = new HttpServer(outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
+ val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf), classServerPort, "HTTP class server")
private var currentSettings: Settings = initialSettings
var printResults = true // whether to print result lines
var totalSilence = false // whether to print anything
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkImports.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkImports.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkImports.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineCompletion.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkJLineReader.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkMemberHandlers.scala
diff --git a/repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala b/repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala
similarity index 100%
rename from repl/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala
rename to repl/scala-2.10/src/main/scala/org/apache/spark/repl/SparkRunnerSettings.scala
diff --git a/repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala b/repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
similarity index 100%
rename from repl/src/test/scala/org/apache/spark/repl/ReplSuite.scala
rename to repl/scala-2.10/src/test/scala/org/apache/spark/repl/ReplSuite.scala
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
new file mode 100644
index 0000000000000..69e44d4f916e1
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/Main.scala
@@ -0,0 +1,85 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.repl
+
+import org.apache.spark.util.Utils
+import org.apache.spark._
+
+import scala.tools.nsc.Settings
+import scala.tools.nsc.interpreter.SparkILoop
+
+object Main extends Logging {
+
+ val conf = new SparkConf()
+ val tmp = System.getProperty("java.io.tmpdir")
+ val rootDir = conf.get("spark.repl.classdir", tmp)
+ val outputDir = Utils.createTempDir(rootDir)
+ val s = new Settings()
+ s.processArguments(List("-Yrepl-class-based",
+ "-Yrepl-outdir", s"${outputDir.getAbsolutePath}", "-Yrepl-sync"), true)
+ val classServer = new HttpServer(conf, outputDir, new SecurityManager(conf))
+ var sparkContext: SparkContext = _
+ var interp = new SparkILoop // this is a public var because tests reset it.
+
+ def main(args: Array[String]) {
+ if (getMaster == "yarn-client") System.setProperty("SPARK_YARN_MODE", "true")
+ // Start the classServer and store its URI in a spark system property
+ // (which will be passed to executors so that they can connect to it)
+ classServer.start()
+ interp.process(s) // Repl starts and goes in loop of R.E.P.L
+ classServer.stop()
+ Option(sparkContext).map(_.stop)
+ }
+
+
+ def getAddedJars: Array[String] = {
+ val envJars = sys.env.get("ADD_JARS")
+ val propJars = sys.props.get("spark.jars").flatMap { p => if (p == "") None else Some(p) }
+ val jars = propJars.orElse(envJars).getOrElse("")
+ Utils.resolveURIs(jars).split(",").filter(_.nonEmpty)
+ }
+
+ def createSparkContext(): SparkContext = {
+ val execUri = System.getenv("SPARK_EXECUTOR_URI")
+ val jars = getAddedJars
+ val conf = new SparkConf()
+ .setMaster(getMaster)
+ .setAppName("Spark shell")
+ .setJars(jars)
+ .set("spark.repl.class.uri", classServer.uri)
+ logInfo("Spark class server started at " + classServer.uri)
+ if (execUri != null) {
+ conf.set("spark.executor.uri", execUri)
+ }
+ if (System.getenv("SPARK_HOME") != null) {
+ conf.setSparkHome(System.getenv("SPARK_HOME"))
+ }
+ sparkContext = new SparkContext(conf)
+ logInfo("Created spark context..")
+ sparkContext
+ }
+
+ private def getMaster: String = {
+ val master = {
+ val envMaster = sys.env.get("MASTER")
+ val propMaster = sys.props.get("spark.master")
+ propMaster.orElse(envMaster).getOrElse("local[*]")
+ }
+ master
+ }
+}
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
new file mode 100644
index 0000000000000..8e519fa67f649
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkExprTyper.scala
@@ -0,0 +1,86 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Paul Phillips
+ */
+
+package scala.tools.nsc
+package interpreter
+
+import scala.tools.nsc.ast.parser.Tokens.EOF
+
+trait SparkExprTyper {
+ val repl: SparkIMain
+
+ import repl._
+ import global.{ reporter => _, Import => _, _ }
+ import naming.freshInternalVarName
+
+ def symbolOfLine(code: String): Symbol = {
+ def asExpr(): Symbol = {
+ val name = freshInternalVarName()
+ // Typing it with a lazy val would give us the right type, but runs
+ // into compiler bugs with things like existentials, so we compile it
+ // behind a def and strip the NullaryMethodType which wraps the expr.
+ val line = "def " + name + " = " + code
+
+ interpretSynthetic(line) match {
+ case IR.Success =>
+ val sym0 = symbolOfTerm(name)
+ // drop NullaryMethodType
+ sym0.cloneSymbol setInfo exitingTyper(sym0.tpe_*.finalResultType)
+ case _ => NoSymbol
+ }
+ }
+ def asDefn(): Symbol = {
+ val old = repl.definedSymbolList.toSet
+
+ interpretSynthetic(code) match {
+ case IR.Success =>
+ repl.definedSymbolList filterNot old match {
+ case Nil => NoSymbol
+ case sym :: Nil => sym
+ case syms => NoSymbol.newOverloaded(NoPrefix, syms)
+ }
+ case _ => NoSymbol
+ }
+ }
+ def asError(): Symbol = {
+ interpretSynthetic(code)
+ NoSymbol
+ }
+ beSilentDuring(asExpr()) orElse beSilentDuring(asDefn()) orElse asError()
+ }
+
+ private var typeOfExpressionDepth = 0
+ def typeOfExpression(expr: String, silent: Boolean = true): Type = {
+ if (typeOfExpressionDepth > 2) {
+ repldbg("Terminating typeOfExpression recursion for expression: " + expr)
+ return NoType
+ }
+ typeOfExpressionDepth += 1
+ // Don't presently have a good way to suppress undesirable success output
+    // while letting errors through, so it first tries silently: if there
+ // is an error, and errors are desired, then it re-evaluates non-silently
+ // to induce the error message.
+ try beSilentDuring(symbolOfLine(expr).tpe) match {
+ case NoType if !silent => symbolOfLine(expr).tpe // generate error
+ case tpe => tpe
+ }
+ finally typeOfExpressionDepth -= 1
+ }
+
+ // This only works for proper types.
+ def typeOfTypeString(typeString: String): Type = {
+ def asProperType(): Option[Type] = {
+ val name = freshInternalVarName()
+ val line = "def %s: %s = ???" format (name, typeString)
+ interpretSynthetic(line) match {
+ case IR.Success =>
+ val sym0 = symbolOfTerm(name)
+ Some(sym0.asMethod.returnType)
+ case _ => None
+ }
+ }
+ beSilentDuring(asProperType()) getOrElse NoType
+ }
+}
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
new file mode 100644
index 0000000000000..250727305970d
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkILoop.scala
@@ -0,0 +1,969 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Alexander Spoon
+ */
+
+package scala
+package tools.nsc
+package interpreter
+
+import scala.language.{ implicitConversions, existentials }
+import scala.annotation.tailrec
+import Predef.{ println => _, _ }
+import interpreter.session._
+import StdReplTags._
+import scala.reflect.api.{Mirror, Universe, TypeCreator}
+import scala.util.Properties.{ jdkHome, javaVersion, versionString, javaVmName }
+import scala.tools.nsc.util.{ ClassPath, Exceptional, stringFromWriter, stringFromStream }
+import scala.reflect.{ClassTag, classTag}
+import scala.reflect.internal.util.{ BatchSourceFile, ScalaClassLoader }
+import ScalaClassLoader._
+import scala.reflect.io.{ File, Directory }
+import scala.tools.util._
+import scala.collection.generic.Clearable
+import scala.concurrent.{ ExecutionContext, Await, Future, future }
+import ExecutionContext.Implicits._
+import java.io.{ BufferedReader, FileReader }
+
+/** The Scala interactive shell. It provides a read-eval-print loop
+ * around the Interpreter class.
+ * After instantiation, clients should call the main() method.
+ *
+ * If no in0 is specified, then input will come from the console, and
+  * the class will attempt to provide input editing features such as
+ * input history.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ * @version 1.2
+ */
+class SparkILoop(in0: Option[BufferedReader], protected val out: JPrintWriter)
+ extends AnyRef
+ with LoopCommands
+{
+ def this(in0: BufferedReader, out: JPrintWriter) = this(Some(in0), out)
+ def this() = this(None, new JPrintWriter(Console.out, true))
+//
+// @deprecated("Use `intp` instead.", "2.9.0") def interpreter = intp
+// @deprecated("Use `intp` instead.", "2.9.0") def interpreter_= (i: Interpreter): Unit = intp = i
+
+ var in: InteractiveReader = _ // the input stream from which commands come
+ var settings: Settings = _
+ var intp: SparkIMain = _
+
+ var globalFuture: Future[Boolean] = _
+
+ protected def asyncMessage(msg: String) {
+ if (isReplInfo || isReplPower)
+ echoAndRefresh(msg)
+ }
+
+ def initializeSpark() {
+ intp.beQuietDuring {
+ command( """
+ @transient val sc = {
+ val _sc = org.apache.spark.repl.Main.createSparkContext()
+ println("Spark context available as sc.")
+ _sc
+ }
+ """)
+ command("import org.apache.spark.SparkContext._")
+ }
+ }
+
+ /** Print a welcome message */
+ def printWelcome() {
+ import org.apache.spark.SPARK_VERSION
+ echo("""Welcome to
+ ____ __
+ / __/__ ___ _____/ /__
+ _\ \/ _ \/ _ `/ __/ '_/
+ /___/ .__/\_,_/_/ /_/\_\ version %s
+ /_/
+ """.format(SPARK_VERSION))
+ val welcomeMsg = "Using Scala %s (%s, Java %s)".format(
+ versionString, javaVmName, javaVersion)
+ echo(welcomeMsg)
+ echo("Type in expressions to have them evaluated.")
+ echo("Type :help for more information.")
+ }
+
+ override def echoCommandMessage(msg: String) {
+ intp.reporter printUntruncatedMessage msg
+ }
+
+ // lazy val power = new Power(intp, new StdReplVals(this))(tagOfStdReplVals, classTag[StdReplVals])
+ def history = in.history
+
+ // classpath entries added via :cp
+ var addedClasspath: String = ""
+
+ /** A reverse list of commands to replay if the user requests a :replay */
+ var replayCommandStack: List[String] = Nil
+
+ /** A list of commands to replay if the user requests a :replay */
+ def replayCommands = replayCommandStack.reverse
+
+ /** Record a command for replay should the user request a :replay */
+ def addReplay(cmd: String) = replayCommandStack ::= cmd
+
+ def savingReplayStack[T](body: => T): T = {
+ val saved = replayCommandStack
+ try body
+ finally replayCommandStack = saved
+ }
+ def savingReader[T](body: => T): T = {
+ val saved = in
+ try body
+ finally in = saved
+ }
+
+ /** Close the interpreter and set the var to null. */
+ def closeInterpreter() {
+ if (intp ne null) {
+ intp.close()
+ intp = null
+ }
+ }
+
+ class SparkILoopInterpreter extends SparkIMain(settings, out) {
+ outer =>
+
+ override lazy val formatting = new Formatting {
+ def prompt = SparkILoop.this.prompt
+ }
+ override protected def parentClassLoader =
+ settings.explicitParentLoader.getOrElse( classOf[SparkILoop].getClassLoader )
+ }
+
+ /** Create a new interpreter. */
+ def createInterpreter() {
+ if (addedClasspath != "")
+ settings.classpath append addedClasspath
+
+ intp = new SparkILoopInterpreter
+ }
+
+ /** print a friendly help message */
+ def helpCommand(line: String): Result = {
+ if (line == "") helpSummary()
+ else uniqueCommand(line) match {
+ case Some(lc) => echo("\n" + lc.help)
+ case _ => ambiguousError(line)
+ }
+ }
+ private def helpSummary() = {
+ val usageWidth = commands map (_.usageMsg.length) max
+ val formatStr = "%-" + usageWidth + "s %s"
+
+ echo("All commands can be abbreviated, e.g. :he instead of :help.")
+
+ commands foreach { cmd =>
+ echo(formatStr.format(cmd.usageMsg, cmd.help))
+ }
+ }
+ private def ambiguousError(cmd: String): Result = {
+ matchingCommands(cmd) match {
+ case Nil => echo(cmd + ": no such command. Type :help for help.")
+ case xs => echo(cmd + " is ambiguous: did you mean " + xs.map(":" + _.name).mkString(" or ") + "?")
+ }
+ Result(keepRunning = true, None)
+ }
+ private def matchingCommands(cmd: String) = commands filter (_.name startsWith cmd)
+ private def uniqueCommand(cmd: String): Option[LoopCommand] = {
+ // this lets us add commands willy-nilly and only requires enough command to disambiguate
+ matchingCommands(cmd) match {
+ case List(x) => Some(x)
+ // exact match OK even if otherwise appears ambiguous
+ case xs => xs find (_.name == cmd)
+ }
+ }
+
+ /** Show the history */
+ lazy val historyCommand = new LoopCommand("history", "show the history (optional num is commands to show)") {
+ override def usage = "[num]"
+ def defaultLines = 20
+
+ def apply(line: String): Result = {
+ if (history eq NoHistory)
+ return "No history available."
+
+ val xs = words(line)
+ val current = history.index
+ val count = try xs.head.toInt catch { case _: Exception => defaultLines }
+ val lines = history.asStrings takeRight count
+ val offset = current - lines.size + 1
+
+ for ((line, index) <- lines.zipWithIndex)
+ echo("%3d %s".format(index + offset, line))
+ }
+ }
+
+ // When you know you are most likely breaking into the middle
+ // of a line being typed. This softens the blow.
+ protected def echoAndRefresh(msg: String) = {
+ echo("\n" + msg)
+ in.redrawLine()
+ }
+ protected def echo(msg: String) = {
+ out println msg
+ out.flush()
+ }
+
+ /** Search the history */
+ def searchHistory(_cmdline: String) {
+ val cmdline = _cmdline.toLowerCase
+ val offset = history.index - history.size + 1
+
+ for ((line, index) <- history.asStrings.zipWithIndex ; if line.toLowerCase contains cmdline)
+ echo("%d %s".format(index + offset, line))
+ }
+
+ private val currentPrompt = Properties.shellPromptString
+
+ /** Prompt to print when awaiting input */
+ def prompt = currentPrompt
+
+ import LoopCommand.{ cmd, nullary }
+
+ /** Standard commands **/
+ lazy val standardCommands = List(
+ cmd("cp", "", "add a jar or directory to the classpath", addClasspath),
+ cmd("edit", "|", "edit history", editCommand),
+ cmd("help", "[command]", "print this summary or command-specific help", helpCommand),
+ historyCommand,
+ cmd("h?", "", "search the history", searchHistory),
+ cmd("imports", "[name name ...]", "show import history, identifying sources of names", importsCommand),
+ //cmd("implicits", "[-v]", "show the implicits in scope", intp.implicitsCommand),
+ cmd("javap", "", "disassemble a file or class name", javapCommand),
+ cmd("line", "|", "place line(s) at the end of history", lineCommand),
+ cmd("load", "", "interpret lines in a file", loadCommand),
+ cmd("paste", "[-raw] [path]", "enter paste mode or paste a file", pasteCommand),
+ // nullary("power", "enable power user mode", powerCmd),
+ nullary("quit", "exit the interpreter", () => Result(keepRunning = false, None)),
+ nullary("replay", "reset execution and replay all previous commands", replay),
+ nullary("reset", "reset the repl to its initial state, forgetting all session entries", resetCommand),
+ cmd("save", "", "save replayable session to a file", saveCommand),
+ shCommand,
+ cmd("settings", "[+|-]", "+enable/-disable flags, set compiler options", changeSettings),
+ nullary("silent", "disable/enable automatic printing of results", verbosity),
+// cmd("type", "[-v] ", "display the type of an expression without evaluating it", typeCommand),
+// cmd("kind", "[-v] ", "display the kind of expression's type", kindCommand),
+ nullary("warnings", "show the suppressed warnings from the most recent line which had any", warningsCommand)
+ )
+
+ /** Power user commands */
+// lazy val powerCommands: List[LoopCommand] = List(
+// cmd("phase", "", "set the implicit phase for power commands", phaseCommand)
+// )
+
+ private def importsCommand(line: String): Result = {
+ val tokens = words(line)
+ val handlers = intp.languageWildcardHandlers ++ intp.importHandlers
+
+ handlers.filterNot(_.importedSymbols.isEmpty).zipWithIndex foreach {
+ case (handler, idx) =>
+ val (types, terms) = handler.importedSymbols partition (_.name.isTypeName)
+ val imps = handler.implicitSymbols
+ val found = tokens filter (handler importsSymbolNamed _)
+ val typeMsg = if (types.isEmpty) "" else types.size + " types"
+ val termMsg = if (terms.isEmpty) "" else terms.size + " terms"
+ val implicitMsg = if (imps.isEmpty) "" else imps.size + " are implicit"
+ val foundMsg = if (found.isEmpty) "" else found.mkString(" // imports: ", ", ", "")
+ val statsMsg = List(typeMsg, termMsg, implicitMsg) filterNot (_ == "") mkString ("(", ", ", ")")
+
+ intp.reporter.printMessage("%2d) %-30s %s%s".format(
+ idx + 1,
+ handler.importString,
+ statsMsg,
+ foundMsg
+ ))
+ }
+ }
+
+ private def findToolsJar() = PathResolver.SupplementalLocations.platformTools
+
+ private def addToolsJarToLoader() = {
+ val cl = findToolsJar() match {
+ case Some(tools) => ScalaClassLoader.fromURLs(Seq(tools.toURL), intp.classLoader)
+ case _ => intp.classLoader
+ }
+ if (Javap.isAvailable(cl)) {
+ repldbg(":javap available.")
+ cl
+ }
+ else {
+ repldbg(":javap unavailable: no tools.jar at " + jdkHome)
+ intp.classLoader
+ }
+ }
+//
+// protected def newJavap() =
+// JavapClass(addToolsJarToLoader(), new IMain.ReplStrippingWriter(intp), Some(intp))
+//
+// private lazy val javap = substituteAndLog[Javap]("javap", NoJavap)(newJavap())
+
+ // Still todo: modules.
+// private def typeCommand(line0: String): Result = {
+// line0.trim match {
+// case "" => ":type [-v] "
+// case s => intp.typeCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
+// }
+// }
+
+// private def kindCommand(expr: String): Result = {
+// expr.trim match {
+// case "" => ":kind [-v] "
+// case s => intp.kindCommandInternal(s stripPrefix "-v " trim, verbose = s startsWith "-v ")
+// }
+// }
+
+ private def warningsCommand(): Result = {
+ if (intp.lastWarnings.isEmpty)
+ "Can't find any cached warnings."
+ else
+ intp.lastWarnings foreach { case (pos, msg) => intp.reporter.warning(pos, msg) }
+ }
+
+ private def changeSettings(args: String): Result = {
+ def showSettings() = {
+ for (s <- settings.userSetSettings.toSeq.sorted) echo(s.toString)
+ }
+ def updateSettings() = {
+ // put aside +flag options
+ val (pluses, rest) = (args split "\\s+").toList partition (_.startsWith("+"))
+ val tmps = new Settings
+ val (ok, leftover) = tmps.processArguments(rest, processAll = true)
+ if (!ok) echo("Bad settings request.")
+ else if (leftover.nonEmpty) echo("Unprocessed settings.")
+ else {
+ // boolean flags set-by-user on tmp copy should be off, not on
+ val offs = tmps.userSetSettings filter (_.isInstanceOf[Settings#BooleanSetting])
+ val (minuses, nonbools) = rest partition (arg => offs exists (_ respondsTo arg))
+ // update non-flags
+ settings.processArguments(nonbools, processAll = true)
+ // also snag multi-value options for clearing, e.g. -Ylog: and -language:
+ for {
+ s <- settings.userSetSettings
+ if s.isInstanceOf[Settings#MultiStringSetting] || s.isInstanceOf[Settings#PhasesSetting]
+ if nonbools exists (arg => arg.head == '-' && arg.last == ':' && (s respondsTo arg.init))
+ } s match {
+ case c: Clearable => c.clear()
+ case _ =>
+ }
+ def update(bs: Seq[String], name: String=>String, setter: Settings#Setting=>Unit) = {
+ for (b <- bs)
+ settings.lookupSetting(name(b)) match {
+ case Some(s) =>
+ if (s.isInstanceOf[Settings#BooleanSetting]) setter(s)
+ else echo(s"Not a boolean flag: $b")
+ case _ =>
+ echo(s"Not an option: $b")
+ }
+ }
+ update(minuses, identity, _.tryToSetFromPropertyValue("false")) // turn off
+ update(pluses, "-" + _.drop(1), _.tryToSet(Nil)) // turn on
+ }
+ }
+ if (args.isEmpty) showSettings() else updateSettings()
+ }
+
+ private def javapCommand(line: String): Result = {
+// if (javap == null)
+// ":javap unavailable, no tools.jar at %s. Set JDK_HOME.".format(jdkHome)
+// else if (line == "")
+// ":javap [-lcsvp] [path1 path2 ...]"
+// else
+// javap(words(line)) foreach { res =>
+// if (res.isError) return "Failed: " + res.value
+// else res.show()
+// }
+ }
+
+ private def pathToPhaseWrapper = intp.originalPath("$r") + ".phased.atCurrent"
+
+ private def phaseCommand(name: String): Result = {
+// val phased: Phased = power.phased
+// import phased.NoPhaseName
+//
+// if (name == "clear") {
+// phased.set(NoPhaseName)
+// intp.clearExecutionWrapper()
+// "Cleared active phase."
+// }
+// else if (name == "") phased.get match {
+// case NoPhaseName => "Usage: :phase (e.g. typer, erasure.next, erasure+3)"
+// case ph => "Active phase is '%s'. (To clear, :phase clear)".format(phased.get)
+// }
+// else {
+// val what = phased.parse(name)
+// if (what.isEmpty || !phased.set(what))
+// "'" + name + "' does not appear to represent a valid phase."
+// else {
+// intp.setExecutionWrapper(pathToPhaseWrapper)
+// val activeMessage =
+// if (what.toString.length == name.length) "" + what
+// else "%s (%s)".format(what, name)
+//
+// "Active phase is now: " + activeMessage
+// }
+// }
+ }
+
+ /** Available commands */
+ def commands: List[LoopCommand] = standardCommands ++ (
+ // if (isReplPower)
+ // powerCommands
+ // else
+ Nil
+ )
+
+ val replayQuestionMessage =
+ """|That entry seems to have slain the compiler. Shall I replay
+ |your session? I can re-run each line except the last one.
+ |[y/n]
+ """.trim.stripMargin
+
+ private val crashRecovery: PartialFunction[Throwable, Boolean] = {
+ case ex: Throwable =>
+ val (err, explain) = (
+ if (intp.isInitializeComplete)
+ (intp.global.throwableAsString(ex), "")
+ else
+ (ex.getMessage, "The compiler did not initialize.\n")
+ )
+ echo(err)
+
+ ex match {
+ case _: NoSuchMethodError | _: NoClassDefFoundError =>
+ echo("\nUnrecoverable error.")
+ throw ex
+ case _ =>
+ def fn(): Boolean =
+ try in.readYesOrNo(explain + replayQuestionMessage, { echo("\nYou must enter y or n.") ; fn() })
+ catch { case _: RuntimeException => false }
+
+ if (fn()) replay()
+ else echo("\nAbandoning crashed session.")
+ }
+ true
+ }
+
+ // return false if repl should exit
+ def processLine(line: String): Boolean = {
+ import scala.concurrent.duration._
+ Await.ready(globalFuture, 60.seconds)
+
+ (line ne null) && (command(line) match {
+ case Result(false, _) => false
+ case Result(_, Some(line)) => addReplay(line) ; true
+ case _ => true
+ })
+ }
+
+ private def readOneLine() = {
+ out.flush()
+ in readLine prompt
+ }
+
+ /** The main read-eval-print loop for the repl. It calls
+ * command() for each line of input, and stops when
+ * command() returns false.
+ */
+ @tailrec final def loop() {
+ if ( try processLine(readOneLine()) catch crashRecovery )
+ loop()
+ }
+
+ /** interpret all lines from a specified file */
+ def interpretAllFrom(file: File) {
+ savingReader {
+ savingReplayStack {
+ file applyReader { reader =>
+ in = SimpleReader(reader, out, interactive = false)
+ echo("Loading " + file + "...")
+ loop()
+ }
+ }
+ }
+ }
+
+ /** create a new interpreter and replay the given commands */
+ def replay() {
+ reset()
+ if (replayCommandStack.isEmpty)
+ echo("Nothing to replay.")
+ else for (cmd <- replayCommands) {
+ echo("Replaying: " + cmd) // flush because maybe cmd will have its own output
+ command(cmd)
+ echo("")
+ }
+ }
+ def resetCommand() {
+ echo("Resetting interpreter state.")
+ if (replayCommandStack.nonEmpty) {
+ echo("Forgetting this session history:\n")
+ replayCommands foreach echo
+ echo("")
+ replayCommandStack = Nil
+ }
+ if (intp.namedDefinedTerms.nonEmpty)
+ echo("Forgetting all expression results and named terms: " + intp.namedDefinedTerms.mkString(", "))
+ if (intp.definedTypes.nonEmpty)
+ echo("Forgetting defined types: " + intp.definedTypes.mkString(", "))
+
+ reset()
+ }
+ def reset() {
+ intp.reset()
+ unleashAndSetPhase()
+ }
+
+ def lineCommand(what: String): Result = editCommand(what, None)
+
+ // :edit id or :edit line
+ def editCommand(what: String): Result = editCommand(what, Properties.envOrNone("EDITOR"))
+
+ def editCommand(what: String, editor: Option[String]): Result = {
+ def diagnose(code: String) = {
+ echo("The edited code is incomplete!\n")
+ val errless = intp compileSources new BatchSourceFile("", s"object pastel {\n$code\n}")
+ if (errless) echo("The compiler reports no errors.")
+ }
+ def historicize(text: String) = history match {
+ case jlh: JLineHistory => text.lines foreach jlh.add ; jlh.moveToEnd() ; true
+ case _ => false
+ }
+ def edit(text: String): Result = editor match {
+ case Some(ed) =>
+ val tmp = File.makeTemp()
+ tmp.writeAll(text)
+ try {
+ val pr = new ProcessResult(s"$ed ${tmp.path}")
+ pr.exitCode match {
+ case 0 =>
+ tmp.safeSlurp() match {
+ case Some(edited) if edited.trim.isEmpty => echo("Edited text is empty.")
+ case Some(edited) =>
+ echo(edited.lines map ("+" + _) mkString "\n")
+ val res = intp interpret edited
+ if (res == IR.Incomplete) diagnose(edited)
+ else {
+ historicize(edited)
+ Result(lineToRecord = Some(edited), keepRunning = true)
+ }
+ case None => echo("Can't read edited text. Did you delete it?")
+ }
+ case x => echo(s"Error exit from $ed ($x), ignoring")
+ }
+ } finally {
+ tmp.delete()
+ }
+ case None =>
+ if (historicize(text)) echo("Placing text in recent history.")
+ else echo(f"No EDITOR defined and you can't change history, echoing your text:%n$text")
+ }
+
+ // if what is a number, use it as a line number or range in history
+ def isNum = what forall (c => c.isDigit || c == '-' || c == '+')
+ // except that "-" means last value
+ def isLast = (what == "-")
+ if (isLast || !isNum) {
+ val name = if (isLast) intp.mostRecentVar else what
+ val sym = intp.symbolOfIdent(name)
+ intp.prevRequestList collectFirst { case r if r.defines contains sym => r } match {
+ case Some(req) => edit(req.line)
+ case None => echo(s"No symbol in scope: $what")
+ }
+ } else try {
+ val s = what
+ // line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)
+ val (start, len) =
+ if ((s indexOf '+') > 0) {
+ val (a,b) = s splitAt (s indexOf '+')
+ (a.toInt, b.drop(1).toInt)
+ } else {
+ (s indexOf '-') match {
+ case -1 => (s.toInt, 1)
+ case 0 => val n = s.drop(1).toInt ; (history.index - n, n)
+ case _ if s.last == '-' => val n = s.init.toInt ; (n, history.index - n)
+ case i => val n = s.take(i).toInt ; (n, s.drop(i+1).toInt - n)
+ }
+ }
+ import scala.collection.JavaConverters._
+ val index = (start - 1) max 0
+ val text = history match {
+ case jlh: JLineHistory => jlh.entries(index).asScala.take(len) map (_.value) mkString "\n"
+ case _ => history.asStrings.slice(index, index + len) mkString "\n"
+ }
+ edit(text)
+ } catch {
+ case _: NumberFormatException => echo(s"Bad range '$what'")
+ echo("Use line 123, 120+3, -3, 120-123, 120-, note -3 is not 0-3 but (cur-3,cur)")
+ }
+ }
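+
+  // Rough sketch of how the parser above maps a history argument to (start, len);
+  // the pairs follow directly from the code and are illustrative only:
+  //   "123"     -> (123, 1)                    // a single line
+  //   "120+3"   -> (120, 3)                    // three lines starting at 120
+  //   "-3"      -> (history.index - 3, 3)      // the last three lines
+  //   "120-"    -> (120, history.index - 120)  // from 120 to the current index
+  //   "120-123" -> (120, 3)                    // 123 - 120 = 3 lines from 120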
+
+ /** fork a shell and run a command */
+ lazy val shCommand = new LoopCommand("sh", "run a shell command (result is implicitly => List[String])") {
+ override def usage = ""
+ def apply(line: String): Result = line match {
+ case "" => showUsage()
+ case _ =>
+ val toRun = s"new ${classOf[ProcessResult].getName}(${string2codeQuoted(line)})"
+ intp interpret toRun
+ ()
+ }
+ }
+
+ def withFile[A](filename: String)(action: File => A): Option[A] = {
+ val res = Some(File(filename)) filter (_.exists) map action
+ if (res.isEmpty) echo("That file does not exist") // courtesy side-effect
+ res
+ }
+
+ def loadCommand(arg: String) = {
+ var shouldReplay: Option[String] = None
+ withFile(arg)(f => {
+ interpretAllFrom(f)
+ shouldReplay = Some(":load " + arg)
+ })
+ Result(keepRunning = true, shouldReplay)
+ }
+
+ def saveCommand(filename: String): Result = (
+ if (filename.isEmpty) echo("File name is required.")
+ else if (replayCommandStack.isEmpty) echo("No replay commands in session")
+ else File(filename).printlnAll(replayCommands: _*)
+ )
+
+ def addClasspath(arg: String): Unit = {
+ val f = File(arg).normalize
+ if (f.exists) {
+ addedClasspath = ClassPath.join(addedClasspath, f.path)
+ val totalClasspath = ClassPath.join(settings.classpath.value, addedClasspath)
+ echo("Added '%s'. Your new classpath is:\n\"%s\"".format(f.path, totalClasspath))
+ replay()
+ }
+ else echo("The path '" + f + "' doesn't seem to exist.")
+ }
+
+ def powerCmd(): Result = {
+ if (isReplPower) "Already in power mode."
+ else enablePowerMode(isDuringInit = false)
+ }
+ def enablePowerMode(isDuringInit: Boolean) = {
+ replProps.power setValue true
+ unleashAndSetPhase()
+ // asyncEcho(isDuringInit, power.banner)
+ }
+ private def unleashAndSetPhase() {
+ if (isReplPower) {
+ // power.unleash()
+ // Set the phase to "typer"
+ // intp beSilentDuring phaseCommand("typer")
+ }
+ }
+
+ def asyncEcho(async: Boolean, msg: => String) {
+ if (async) asyncMessage(msg)
+ else echo(msg)
+ }
+
+ def verbosity() = {
+ val old = intp.printResults
+ intp.printResults = !old
+ echo("Switched " + (if (old) "off" else "on") + " result printing.")
+ }
+
+ /** Run one command submitted by the user. Two values are returned:
+ * (1) whether to keep running, (2) the line to record for replay,
+ * if any. */
+ def command(line: String): Result = {
+ if (line startsWith ":") {
+ val cmd = line.tail takeWhile (x => !x.isWhitespace)
+ uniqueCommand(cmd) match {
+ case Some(lc) => lc(line.tail stripPrefix cmd dropWhile (_.isWhitespace))
+ case _ => ambiguousError(cmd)
+ }
+ }
+ else if (intp.global == null) Result(keepRunning = false, None) // Notice failure to create compiler
+ else Result(keepRunning = true, interpretStartingWith(line))
+ }
+
+ private def readWhile(cond: String => Boolean) = {
+ Iterator continually in.readLine("") takeWhile (x => x != null && cond(x))
+ }
+
+ def pasteCommand(arg: String): Result = {
+ var shouldReplay: Option[String] = None
+ def result = Result(keepRunning = true, shouldReplay)
+ val (raw, file) =
+ if (arg.isEmpty) (false, None)
+ else {
+ val r = """(-raw)?(\s+)?([^\-]\S*)?""".r
+ arg match {
+ case r(flag, sep, name) =>
+ if (flag != null && name != null && sep == null)
+ echo(s"""I assume you mean "$flag $name"?""")
+ (flag != null, Option(name))
+ case _ =>
+ echo("usage: :paste -raw file")
+ return result
+ }
+ }
+ val code = file match {
+ case Some(name) =>
+ withFile(name)(f => {
+ shouldReplay = Some(s":paste $arg")
+ val s = f.slurp.trim
+ if (s.isEmpty) echo(s"File contains no code: $f")
+ else echo(s"Pasting file $f...")
+ s
+ }) getOrElse ""
+ case None =>
+ echo("// Entering paste mode (ctrl-D to finish)\n")
+ val text = (readWhile(_ => true) mkString "\n").trim
+ if (text.isEmpty) echo("\n// Nothing pasted, nothing gained.\n")
+ else echo("\n// Exiting paste mode, now interpreting.\n")
+ text
+ }
+ def interpretCode() = {
+ val res = intp interpret code
+ // if input is incomplete, let the compiler try to say why
+ if (res == IR.Incomplete) {
+ echo("The pasted code is incomplete!\n")
+ // Remembrance of Things Pasted in an object
+        val errless = intp compileSources new BatchSourceFile("<pastie>", s"object pastel {\n$code\n}")
+ if (errless) echo("...but compilation found no error? Good luck with that.")
+ }
+ }
+ def compileCode() = {
+ val errless = intp compileSources new BatchSourceFile("", code)
+ if (!errless) echo("There were compilation errors!")
+ }
+ if (code.nonEmpty) {
+ if (raw) compileCode() else interpretCode()
+ }
+ result
+ }
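+
+  // Usage sketch for :paste, summarizing the handling above (not extra behavior):
+  //   :paste                  read from the console until ctrl-D, then interpret
+  //   :paste file.scala       interpret the contents of file.scala
+  //   :paste -raw file.scala  compile the text directly (compileCode) instead of
+  //                           interpreting it (interpretCode)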
+
+ private object paste extends Pasted {
+ val ContinueString = " | "
+ val PromptString = "scala> "
+
+ def interpret(line: String): Unit = {
+ echo(line.trim)
+ intp interpret line
+ echo("")
+ }
+
+ def transcript(start: String) = {
+ echo("\n// Detected repl transcript paste: ctrl-D to finish.\n")
+ apply(Iterator(start) ++ readWhile(_.trim != PromptString.trim))
+ }
+ }
+ import paste.{ ContinueString, PromptString }
+
+ /** Interpret expressions starting with the first line.
+ * Read lines until a complete compilation unit is available
+ * or until a syntax error has been seen. If a full unit is
+ * read, go ahead and interpret it. Return the full string
+ * to be recorded for replay, if any.
+ */
+ def interpretStartingWith(code: String): Option[String] = {
+    // signal to completion that non-completion input has been received
+ in.completion.resetVerbosity()
+
+ def reallyInterpret = {
+ val reallyResult = intp.interpret(code)
+ (reallyResult, reallyResult match {
+ case IR.Error => None
+ case IR.Success => Some(code)
+ case IR.Incomplete =>
+ if (in.interactive && code.endsWith("\n\n")) {
+ echo("You typed two blank lines. Starting a new command.")
+ None
+ }
+ else in.readLine(ContinueString) match {
+ case null =>
+ // we know compilation is going to fail since we're at EOF and the
+ // parser thinks the input is still incomplete, but since this is
+ // a file being read non-interactively we want to fail. So we send
+ // it straight to the compiler for the nice error message.
+ intp.compileString(code)
+ None
+
+ case line => interpretStartingWith(code + "\n" + line)
+ }
+ })
+ }
+
+ /** Here we place ourselves between the user and the interpreter and examine
+ * the input they are ostensibly submitting. We intervene in several cases:
+ *
+ * 1) If the line starts with "scala> " it is assumed to be an interpreter paste.
+ * 2) If the line starts with "." (but not ".." or "./") it is treated as an invocation
+ * on the previous result.
+ * 3) If the Completion object's execute returns Some(_), we inject that value
+ * and avoid the interpreter, as it's likely not valid scala code.
+ */
+ if (code == "") None
+ else if (!paste.running && code.trim.startsWith(PromptString)) {
+ paste.transcript(code)
+ None
+ }
+ else if (Completion.looksLikeInvocation(code) && intp.mostRecentVar != "") {
+ interpretStartingWith(intp.mostRecentVar + code)
+ }
+ else if (code.trim startsWith "//") {
+ // line comment, do nothing
+ None
+ }
+ else
+ reallyInterpret._2
+ }
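+
+  // Illustrative sketch of the continuation handling above:
+  //   scala> val x =
+  //        | 42
+  // interpretStartingWith("val x =") gets IR.Incomplete, reads the next line at
+  // the ContinueString prompt, and finally interprets "val x =\n42"; a line that
+  // begins with "scala> " is routed to paste.transcript instead.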
+
+ // runs :load `file` on any files passed via -i
+ def loadFiles(settings: Settings) = settings match {
+ case settings: GenericRunnerSettings =>
+ for (filename <- settings.loadfiles.value) {
+ val cmd = ":load " + filename
+ command(cmd)
+ addReplay(cmd)
+ echo("")
+ }
+ case _ =>
+ }
+
+  /** Tries to create a JLineReader, falling back to SimpleReader
+   * unless settings or properties are such that it should start
+   * with SimpleReader.
+   */
+ def chooseReader(settings: Settings): InteractiveReader = {
+ if (settings.Xnojline || Properties.isEmacsShell)
+ SimpleReader()
+ else try new JLineReader(
+ if (settings.noCompletion) NoCompletion
+ else new SparkJLineCompletion(intp)
+ )
+ catch {
+ case ex @ (_: Exception | _: NoClassDefFoundError) =>
+ echo("Failed to created JLineReader: " + ex + "\nFalling back to SimpleReader.")
+ SimpleReader()
+ }
+ }
+ protected def tagOfStaticClass[T: ClassTag]: u.TypeTag[T] =
+ u.TypeTag[T](
+ m,
+ new TypeCreator {
+ def apply[U <: Universe with Singleton](m: Mirror[U]): U # Type =
+ m.staticClass(classTag[T].runtimeClass.getName).toTypeConstructor.asInstanceOf[U # Type]
+ })
+
+ private def loopPostInit() {
+ // Bind intp somewhere out of the regular namespace where
+ // we can get at it in generated code.
+ intp.quietBind(NamedParam[SparkIMain]("$intp", intp)(tagOfStaticClass[SparkIMain], classTag[SparkIMain]))
+ // Auto-run code via some setting.
+ ( replProps.replAutorunCode.option
+ flatMap (f => io.File(f).safeSlurp())
+ foreach (intp quietRun _)
+ )
+ // classloader and power mode setup
+ intp.setContextClassLoader()
+ if (isReplPower) {
+ // replProps.power setValue true
+ // unleashAndSetPhase()
+ // asyncMessage(power.banner)
+ }
+ // SI-7418 Now, and only now, can we enable TAB completion.
+ in match {
+ case x: JLineReader => x.consoleReader.postInit
+ case _ =>
+ }
+ }
+ def process(settings: Settings): Boolean = savingContextLoader {
+ this.settings = settings
+ createInterpreter()
+
+ // sets in to some kind of reader depending on environmental cues
+ in = in0.fold(chooseReader(settings))(r => SimpleReader(r, out, interactive = true))
+ globalFuture = future {
+ intp.initializeSynchronous()
+ loopPostInit()
+ !intp.reporter.hasErrors
+ }
+ import scala.concurrent.duration._
+    Await.ready(globalFuture, 10.seconds)
+ printWelcome()
+ initializeSpark()
+ loadFiles(settings)
+
+ try loop()
+ catch AbstractOrMissingHandler()
+ finally closeInterpreter()
+
+ true
+ }
+
+ @deprecated("Use `process` instead", "2.9.0")
+ def main(settings: Settings): Unit = process(settings) //used by sbt
+}
+
+object SparkILoop {
+ implicit def loopToInterpreter(repl: SparkILoop): SparkIMain = repl.intp
+
+  // Designed primarily for use by test code: takes a String with a
+  // bunch of code and prints out a transcript of what it would look
+  // like if you'd just typed it into the repl.
+ def runForTranscript(code: String, settings: Settings): String = {
+ import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
+
+ stringFromStream { ostream =>
+ Console.withOut(ostream) {
+ val output = new JPrintWriter(new OutputStreamWriter(ostream), true) {
+ override def write(str: String) = {
+ // completely skip continuation lines
+ if (str forall (ch => ch.isWhitespace || ch == '|')) ()
+ else super.write(str)
+ }
+ }
+ val input = new BufferedReader(new StringReader(code.trim + "\n")) {
+ override def readLine(): String = {
+ val s = super.readLine()
+ // helping out by printing the line being interpreted.
+ if (s != null)
+ output.println(s)
+ s
+ }
+ }
+ val repl = new SparkILoop(input, output)
+ if (settings.classpath.isDefault)
+ settings.classpath.value = sys.props("java.class.path")
+
+ repl process settings
+ }
+ }
+ }
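+
+  // Illustrative usage (sketch; the Settings instance here is just a default one):
+  //   val transcript = SparkILoop.runForTranscript("val x = 1\nx + 1", new Settings)
+  //   // `transcript` interleaves the input lines with the repl's output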
+
+ /** Creates an interpreter loop with default settings and feeds
+ * the given code to it as input.
+ */
+ def run(code: String, sets: Settings = new Settings): String = {
+ import java.io.{ BufferedReader, StringReader, OutputStreamWriter }
+
+ stringFromStream { ostream =>
+ Console.withOut(ostream) {
+ val input = new BufferedReader(new StringReader(code))
+ val output = new JPrintWriter(new OutputStreamWriter(ostream), true)
+ val repl = new SparkILoop(input, output)
+
+ if (sets.classpath.isDefault)
+ sets.classpath.value = sys.props("java.class.path")
+
+ repl process sets
+ }
+ }
+ }
+ def run(lines: List[String]): String = run(lines map (_ + "\n") mkString)
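+
+  // Illustrative usage of the two run overloads above (sketch only):
+  //   SparkILoop.run("val n = 21\nn * 2")          // default Settings
+  //   SparkILoop.run(List("val n = 21", "n * 2"))  // one statement per element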
+}
diff --git a/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
new file mode 100644
index 0000000000000..1bb62c84abddc
--- /dev/null
+++ b/repl/scala-2.11/src/main/scala/org/apache/spark/repl/SparkIMain.scala
@@ -0,0 +1,1319 @@
+/* NSC -- new Scala compiler
+ * Copyright 2005-2013 LAMP/EPFL
+ * @author Martin Odersky
+ */
+
+package scala
+package tools.nsc
+package interpreter
+
+import PartialFunction.cond
+import scala.language.implicitConversions
+import scala.beans.BeanProperty
+import scala.collection.mutable
+import scala.concurrent.{ Future, ExecutionContext }
+import scala.reflect.runtime.{ universe => ru }
+import scala.reflect.{ ClassTag, classTag }
+import scala.reflect.internal.util.{ BatchSourceFile, SourceFile }
+import scala.tools.util.PathResolver
+import scala.tools.nsc.io.AbstractFile
+import scala.tools.nsc.typechecker.{ TypeStrings, StructuredTypeStrings }
+import scala.tools.nsc.util.{ ScalaClassLoader, stringFromReader, stringFromWriter, StackTraceOps }
+import scala.tools.nsc.util.Exceptional.unwrap
+import javax.script.{AbstractScriptEngine, Bindings, ScriptContext, ScriptEngine, ScriptEngineFactory, ScriptException, CompiledScript, Compilable}
+
+/** An interpreter for Scala code.
+ *
+ * The main public entry points are compile(), interpret(), and bind().
+ * The compile() method loads a complete Scala file. The interpret() method
+ * executes one line of Scala code at the request of the user. The bind()
+ * method binds an object to a variable that can then be used by later
+ * interpreted code.
+ *
+ * The overall approach is based on compiling the requested code and then
+ * using a Java classloader and Java reflection to run the code
+ * and access its results.
+ *
+ * In more detail, a single compiler instance is used
+ * to accumulate all successfully compiled or interpreted Scala code. To
+ * "interpret" a line of code, the compiler generates a fresh object that
+ * includes the line of code and which has public member(s) to export
+ * all variables defined by that code. To extract the result of an
+ * interpreted line to show the user, a second "result object" is created
+ * which imports the variables exported by the above object and then
+ * exports members called "$eval" and "$print". To accommodate user expressions
+ * that read from variables or methods defined in previous statements, "import"
+ * statements are used.
+ *
+ * This interpreter shares the strengths and weaknesses of using the
+ * full compiler-to-Java. The main strength is that interpreted code
+ * behaves exactly as does compiled code, including running at full speed.
+ * The main weakness is that redefining classes and methods is not handled
+ * properly, because rebinding at the Java level is technically difficult.
+ *
+ * @author Moez A. Abdel-Gawad
+ * @author Lex Spoon
+ */
+class SparkIMain(@BeanProperty val factory: ScriptEngineFactory, initialSettings: Settings,
+ protected val out: JPrintWriter) extends AbstractScriptEngine with Compilable with SparkImports {
+ imain =>
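+
+  // A rough sketch of the wrapping described in the comment above. The names
+  // $read and $result are illustrative only ($eval and $print are the member
+  // names mentioned above); the interpreter generates its own wrapper names.
+  // A line such as
+  //   val x = 42
+  // is compiled roughly as
+  //   object $read   { val x = 42 }             // exports the definition
+  //   object $result { import $read._           // the "result object"
+  //                    def $eval  = x           // the value itself
+  //                    def $print = x.toString  // rendering shown to the user
+  //                  }
+  // and later lines can refer to x via imports from the earlier wrapper.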
+
+ setBindings(createBindings, ScriptContext.ENGINE_SCOPE)
+ object replOutput extends ReplOutput(settings.Yreploutdir) { }
+
+ @deprecated("Use replOutput.dir instead", "2.11.0")
+ def virtualDirectory = replOutput.dir
+ // Used in a test case.
+ def showDirectory() = replOutput.show(out)
+
+ private[nsc] var printResults = true // whether to print result lines
+ private[nsc] var totalSilence = false // whether to print anything
+ private var _initializeComplete = false // compiler is initialized
+ private var _isInitialized: Future[Boolean] = null // set up initialization future
+ private var bindExceptions = true // whether to bind the lastException variable
+ private var _executionWrapper = "" // code to be wrapped around all lines
+
+ /** We're going to go to some trouble to initialize the compiler asynchronously.
+ * It's critical that nothing call into it until it's been initialized or we will
+ * run into unrecoverable issues, but the perceived repl startup time goes
+ * through the roof if we wait for it. So we initialize it with a future and
+ * use a lazy val to ensure that any attempt to use the compiler object waits
+ * on the future.
+ */
+ private var _classLoader: util.AbstractFileClassLoader = null // active classloader
+ private val _compiler: ReplGlobal = newCompiler(settings, reporter) // our private compiler
+
+ def compilerClasspath: Seq[java.net.URL] = (
+ if (isInitializeComplete) global.classPath.asURLs
+ else new PathResolver(settings).result.asURLs // the compiler's classpath
+ )
+ def settings = initialSettings
+ // Run the code body with the given boolean settings flipped to true.
+ def withoutWarnings[T](body: => T): T = beQuietDuring {
+ val saved = settings.nowarn.value
+ if (!saved)
+ settings.nowarn.value = true
+
+ try body
+ finally if (!saved) settings.nowarn.value = false
+ }
+
+ /** construct an interpreter that reports to Console */
+ def this(settings: Settings, out: JPrintWriter) = this(null, settings, out)
+ def this(factory: ScriptEngineFactory, settings: Settings) = this(factory, settings, new NewLinePrintWriter(new ConsoleWriter, true))
+ def this(settings: Settings) = this(settings, new NewLinePrintWriter(new ConsoleWriter, true))
+ def this(factory: ScriptEngineFactory) = this(factory, new Settings())
+ def this() = this(new Settings())
+
+ lazy val formatting: Formatting = new Formatting {
+ val prompt = Properties.shellPromptString
+ }
+ lazy val reporter: SparkReplReporter = new SparkReplReporter(this)
+
+ import formatting._
+ import reporter.{ printMessage, printUntruncatedMessage }
+
+ // This exists mostly because using the reporter too early leads to deadlock.
+ private def echo(msg: String) { Console println msg }
+  private def _initSources = List(new BatchSourceFile("<init>", "class $repl_$init { }"))
+ private def _initialize() = {
+ try {
+ // if this crashes, REPL will hang its head in shame
+ val run = new _compiler.Run()
+ assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
+ run compileSources _initSources
+ _initializeComplete = true
+ true
+ }
+ catch AbstractOrMissingHandler()
+ }
+ private def tquoted(s: String) = "\"\"\"" + s + "\"\"\""
+ private val logScope = scala.sys.props contains "scala.repl.scope"
+ private def scopelog(msg: String) = if (logScope) Console.err.println(msg)
+
+ // argument is a thunk to execute after init is done
+ def initialize(postInitSignal: => Unit) {
+ synchronized {
+ if (_isInitialized == null) {
+ _isInitialized =
+ Future(try _initialize() finally postInitSignal)(ExecutionContext.global)
+ }
+ }
+ }
+ def initializeSynchronous(): Unit = {
+ if (!isInitializeComplete) {
+ _initialize()
+ assert(global != null, global)
+ }
+ }
+ def isInitializeComplete = _initializeComplete
+
+ lazy val global: Global = {
+ if (!isInitializeComplete) _initialize()
+ _compiler
+ }
+
+ import global._
+ import definitions.{ ObjectClass, termMember, dropNullaryMethod}
+
+ lazy val runtimeMirror = ru.runtimeMirror(classLoader)
+
+ private def noFatal(body: => Symbol): Symbol = try body catch { case _: FatalError => NoSymbol }
+
+ def getClassIfDefined(path: String) = (
+ noFatal(runtimeMirror staticClass path)
+ orElse noFatal(rootMirror staticClass path)
+ )
+ def getModuleIfDefined(path: String) = (
+ noFatal(runtimeMirror staticModule path)
+ orElse noFatal(rootMirror staticModule path)
+ )
+
+ implicit class ReplTypeOps(tp: Type) {
+ def andAlso(fn: Type => Type): Type = if (tp eq NoType) tp else fn(tp)
+ }
+
+ // TODO: If we try to make naming a lazy val, we run into big time
+ // scalac unhappiness with what look like cycles. It has not been easy to
+ // reduce, but name resolution clearly takes different paths.
+ object naming extends {
+ val global: imain.global.type = imain.global
+ } with Naming {
+ // make sure we don't overwrite their unwisely named res3 etc.
+ def freshUserTermName(): TermName = {
+ val name = newTermName(freshUserVarName())
+ if (replScope containsName name) freshUserTermName()
+ else name
+ }
+ def isInternalTermName(name: Name) = isInternalVarName("" + name)
+ }
+ import naming._
+
+ object deconstruct extends {
+ val global: imain.global.type = imain.global
+ } with StructuredTypeStrings
+
+ lazy val memberHandlers = new {
+ val intp: imain.type = imain
+ } with SparkMemberHandlers
+ import memberHandlers._
+
+ /** Temporarily be quiet */
+ def beQuietDuring[T](body: => T): T = {
+ val saved = printResults
+ printResults = false
+ try body
+ finally printResults = saved
+ }
+ def beSilentDuring[T](operation: => T): T = {
+ val saved = totalSilence
+ totalSilence = true
+ try operation
+ finally totalSilence = saved
+ }
+
+ def quietRun[T](code: String) = beQuietDuring(interpret(code))
+
+ /** takes AnyRef because it may be binding a Throwable or an Exceptional */
+ private def withLastExceptionLock[T](body: => T, alt: => T): T = {
+ assert(bindExceptions, "withLastExceptionLock called incorrectly.")
+ bindExceptions = false
+
+ try beQuietDuring(body)
+ catch logAndDiscard("withLastExceptionLock", alt)
+ finally bindExceptions = true
+ }
+
+ def executionWrapper = _executionWrapper
+ def setExecutionWrapper(code: String) = _executionWrapper = code
+ def clearExecutionWrapper() = _executionWrapper = ""
+
+ /** interpreter settings */
+ lazy val isettings = new SparkISettings(this)
+
+ /** Instantiate a compiler. Overridable. */
+ protected def newCompiler(settings: Settings, reporter: reporters.Reporter): ReplGlobal = {
+ settings.outputDirs setSingleOutput replOutput.dir
+ settings.exposeEmptyPackage.value = true
+    new Global(settings, reporter) with ReplGlobal { override def toString: String = "<global>" }
+ }
+
+ /** Parent classloader. Overridable. */
+ protected def parentClassLoader: ClassLoader =
+ settings.explicitParentLoader.getOrElse( this.getClass.getClassLoader() )
+
+ /* A single class loader is used for all commands interpreted by this Interpreter.
+ It would also be possible to create a new class loader for each command
+ to interpret. The advantages of the current approach are:
+
+ - Expressions are only evaluated one time. This is especially
+ significant for I/O, e.g. "val x = Console.readLine"
+
+ The main disadvantage is:
+
+ - Objects, classes, and methods cannot be rebound. Instead, definitions
+ shadow the old ones, and old code objects refer to the old
+ definitions.
+ */
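+  // Illustrative consequence of the shadowing described above (session sketch):
+  //   scala> def f = 1
+  //   scala> def g = f
+  //   scala> def f = 2
+  //   scala> g      // still 1: g was compiled against the earlier f, and the
+  //                 // new definition only shadows it for later lines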
+ def resetClassLoader() = {
+ repldbg("Setting new classloader: was " + _classLoader)
+ _classLoader = null
+ ensureClassLoader()
+ }
+ final def ensureClassLoader() {
+ if (_classLoader == null)
+ _classLoader = makeClassLoader()
+ }
+ def classLoader: util.AbstractFileClassLoader = {
+ ensureClassLoader()
+ _classLoader
+ }
+
+ def backticked(s: String): String = (
+ (s split '.').toList map {
+ case "_" => "_"
+ case s if nme.keywords(newTermName(s)) => s"`$s`"
+ case s => s
+ } mkString "."
+ )
+ def readRootPath(readPath: String) = getModuleIfDefined(readPath)
+
+ abstract class PhaseDependentOps {
+ def shift[T](op: => T): T
+
+ def path(name: => Name): String = shift(path(symbolOfName(name)))
+ def path(sym: Symbol): String = backticked(shift(sym.fullName))
+ def sig(sym: Symbol): String = shift(sym.defString)
+ }
+ object typerOp extends PhaseDependentOps {
+ def shift[T](op: => T): T = exitingTyper(op)
+ }
+ object flatOp extends PhaseDependentOps {
+ def shift[T](op: => T): T = exitingFlatten(op)
+ }
+
+ def originalPath(name: String): String = originalPath(name: TermName)
+ def originalPath(name: Name): String = typerOp path name
+ def originalPath(sym: Symbol): String = typerOp path sym
+ def flatPath(sym: Symbol): String = flatOp shift sym.javaClassName
+ def translatePath(path: String) = {
+ val sym = if (path endsWith "$") symbolOfTerm(path.init) else symbolOfIdent(path)
+ sym.toOption map flatPath
+ }
+ def translateEnclosingClass(n: String) = symbolOfTerm(n).enclClass.toOption map flatPath
+
+ private class TranslatingClassLoader(parent: ClassLoader) extends util.AbstractFileClassLoader(replOutput.dir, parent) {
+ /** Overridden here to try translating a simple name to the generated
+ * class name if the original attempt fails. This method is used by
+ * getResourceAsStream as well as findClass.
+ */
+ override protected def findAbstractFile(name: String): AbstractFile =
+ super.findAbstractFile(name) match {
+ case null if _initializeComplete => translatePath(name) map (super.findAbstractFile(_)) orNull
+ case file => file
+ }
+ }
+ private def makeClassLoader(): util.AbstractFileClassLoader =
+ new TranslatingClassLoader(parentClassLoader match {
+ case null => ScalaClassLoader fromURLs compilerClasspath
+ case p => new ScalaClassLoader.URLClassLoader(compilerClasspath, p)
+ })
+
+ // Set the current Java "context" class loader to this interpreter's class loader
+ def setContextClassLoader() = classLoader.setAsContext()
+
+ def allDefinedNames: List[Name] = exitingTyper(replScope.toList.map(_.name).sorted)
+ def unqualifiedIds: List[String] = allDefinedNames map (_.decode) sorted
+
+ /** Most recent tree handled which wasn't wholly synthetic. */
+ private def mostRecentlyHandledTree: Option[Tree] = {
+ prevRequests.reverse foreach { req =>
+ req.handlers.reverse foreach {
+ case x: MemberDefHandler if x.definesValue && !isInternalTermName(x.name) => return Some(x.member)
+ case _ => ()
+ }
+ }
+ None
+ }
+
+ private def updateReplScope(sym: Symbol, isDefined: Boolean) {
+ def log(what: String) {
+ val mark = if (sym.isType) "t " else "v "
+ val name = exitingTyper(sym.nameString)
+ val info = cleanTypeAfterTyper(sym)
+ val defn = sym defStringSeenAs info
+
+ scopelog(f"[$mark$what%6s] $name%-25s $defn%s")
+ }
+ if (ObjectClass isSubClass sym.owner) return
+ // unlink previous
+ replScope lookupAll sym.name foreach { sym =>
+ log("unlink")
+ replScope unlink sym
+ }
+ val what = if (isDefined) "define" else "import"
+ log(what)
+ replScope enter sym
+ }
+
+ def recordRequest(req: Request) {
+ if (req == null)
+ return
+
+ prevRequests += req
+
+ // warning about serially defining companions. It'd be easy
+ // enough to just redefine them together but that may not always
+ // be what people want so I'm waiting until I can do it better.
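+    // For example (sketch): defining `class C` in one request and `object C` in a
+    // later one triggers the warning below; pasting both together with :paste
+    // keeps them true companions.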
+ exitingTyper {
+ req.defines filterNot (s => req.defines contains s.companionSymbol) foreach { newSym =>
+ val oldSym = replScope lookup newSym.name.companionName
+ if (Seq(oldSym, newSym).permutations exists { case Seq(s1, s2) => s1.isClass && s2.isModule }) {
+ replwarn(s"warning: previously defined $oldSym is not a companion to $newSym.")
+ replwarn("Companions must be defined together; you may wish to use :paste mode for this.")
+ }
+ }
+ }
+ exitingTyper {
+ req.imports foreach (sym => updateReplScope(sym, isDefined = false))
+ req.defines foreach (sym => updateReplScope(sym, isDefined = true))
+ }
+ }
+
+ private[nsc] def replwarn(msg: => String) {
+ if (!settings.nowarnings)
+ printMessage(msg)
+ }
+
+ def compileSourcesKeepingRun(sources: SourceFile*) = {
+ val run = new Run()
+ assert(run.typerPhase != NoPhase, "REPL requires a typer phase.")
+ reporter.reset()
+ run compileSources sources.toList
+ (!reporter.hasErrors, run)
+ }
+
+ /** Compile an nsc SourceFile. Returns true if there are
+ * no compilation errors, or false otherwise.
+ */
+ def compileSources(sources: SourceFile*): Boolean =
+ compileSourcesKeepingRun(sources: _*)._1
+
+ /** Compile a string. Returns true if there are no
+ * compilation errors, or false otherwise.
+ */
+ def compileString(code: String): Boolean =
+ compileSources(new BatchSourceFile("