diff --git a/core/src/main/scala/org/apache/spark/api/python/Py4JServer.scala b/core/src/main/scala/org/apache/spark/api/python/Py4JServer.scala
new file mode 100644
index 000000000000..db440b117892
--- /dev/null
+++ b/core/src/main/scala/org/apache/spark/api/python/Py4JServer.scala
@@ -0,0 +1,69 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import java.net.InetAddress
+import java.util.Locale
+
+import org.apache.spark.SparkConf
+import org.apache.spark.internal.Logging
+import org.apache.spark.util.Utils
+
+/**
+ * A wrapper for both GatewayServer and ClientServer to pin the Python thread to the JVM thread.
+ */
+private[spark] class Py4JServer(sparkConf: SparkConf) extends Logging {
+  private[spark] val secret: String = Utils.createSecret(sparkConf)
+
+  // Launch a Py4J gateway or client server for the process to connect to; this will let it see our
+  // Java system properties and such
+  private val localhost = InetAddress.getLoopbackAddress()
+  private[spark] val server = if (sys.env.getOrElse(
+      "PYSPARK_PIN_THREAD", "false").toLowerCase(Locale.ROOT) == "true") {
+    new py4j.ClientServer.ClientServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .build()
+  } else {
+    new py4j.GatewayServer.GatewayServerBuilder()
+      .authToken(secret)
+      .javaPort(0)
+      .javaAddress(localhost)
+      .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
+      .build()
+  }
+
+  def start(): Unit = server match {
+    case clientServer: py4j.ClientServer => clientServer.startServer()
+    case gatewayServer: py4j.GatewayServer => gatewayServer.start()
+    case other => throw new RuntimeException(s"Unexpected Py4J server ${other.getClass}")
+  }
+
+  def getListeningPort: Int = server match {
+    case clientServer: py4j.ClientServer => clientServer.getJavaServer.getListeningPort
+    case gatewayServer: py4j.GatewayServer => gatewayServer.getListeningPort
+    case other => throw new RuntimeException(s"Unexpected Py4J server ${other.getClass}")
+  }
+
+  def shutdown(): Unit = server match {
+    case clientServer: py4j.ClientServer => clientServer.shutdown()
+    case gatewayServer: py4j.GatewayServer => gatewayServer.shutdown()
+    case other => throw new RuntimeException(s"Unexpected Py4J server ${other.getClass}")
+  }
+}
diff --git a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
index 9ddc4a491018..ed70e26e2520 100644
--- a/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
+++ b/core/src/main/scala/org/apache/spark/api/python/PythonGatewayServer.scala
@@ -18,18 +18,14 @@
 package org.apache.spark.api.python
 
 import java.io.{DataOutputStream, File, FileOutputStream}
-import java.net.InetAddress
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.Files
 
-import py4j.GatewayServer
-
 import org.apache.spark.SparkConf
 import org.apache.spark.internal.Logging
-import org.apache.spark.util.Utils
 
 /**
- * Process that starts a Py4J GatewayServer on an ephemeral port.
+ * Process that starts a Py4J server on an ephemeral port.
  *
  * This process is launched (via SparkSubmit) by the PySpark driver (see java_gateway.py).
  */
@@ -37,23 +33,13 @@ private[spark] object PythonGatewayServer extends Logging {
   initializeLogIfNecessary(true)
 
   def main(args: Array[String]): Unit = {
-    val secret = Utils.createSecret(new SparkConf())
-
-    // Start a GatewayServer on an ephemeral port. Make sure the callback client is configured
-    // with the same secret, in case the app needs callbacks from the JVM to the underlying
-    // python processes.
-    val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer: GatewayServer = new GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
-      .javaPort(0)
-      .javaAddress(localhost)
-      .callbackClient(GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    val sparkConf = new SparkConf()
+    val gatewayServer: Py4JServer = new Py4JServer(sparkConf)
 
     gatewayServer.start()
     val boundPort: Int = gatewayServer.getListeningPort
     if (boundPort == -1) {
-      logError("GatewayServer failed to bind; exiting")
+      logError(s"${gatewayServer.server.getClass} failed to bind; exiting")
       System.exit(1)
     } else {
       logDebug(s"Started PythonGatewayServer on port $boundPort")
@@ -68,7 +54,7 @@ private[spark] object PythonGatewayServer extends Logging {
 
     val dos = new DataOutputStream(new FileOutputStream(tmpPath))
     dos.writeInt(boundPort)
-    val secretBytes = secret.getBytes(UTF_8)
+    val secretBytes = gatewayServer.secret.getBytes(UTF_8)
     dos.writeInt(secretBytes.length)
     dos.write(secretBytes, 0, secretBytes.length)
     dos.close()
diff --git a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
index 0c9d34986af6..574ce60b19b4 100644
--- a/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
+++ b/core/src/main/scala/org/apache/spark/deploy/PythonRunner.scala
@@ -18,7 +18,7 @@
 package org.apache.spark.deploy
 
 import java.io.File
-import java.net.{InetAddress, URI}
+import java.net.URI
 import java.nio.file.Files
 
 import scala.collection.JavaConverters._
@@ -26,7 +26,7 @@
 import scala.collection.mutable.ArrayBuffer
 import scala.util.Try
 
 import org.apache.spark.{SparkConf, SparkUserAppException}
-import org.apache.spark.api.python.PythonUtils
+import org.apache.spark.api.python.{Py4JServer, PythonUtils}
 import org.apache.spark.internal.config._
 import org.apache.spark.util.{RedirectThread, Utils}
 
@@ -40,7 +40,6 @@ object PythonRunner {
     val pyFiles = args(1)
     val otherArgs = args.slice(2, args.length)
     val sparkConf = new SparkConf()
-    val secret = Utils.createSecret(sparkConf)
     val pythonExec = sparkConf.get(PYSPARK_DRIVER_PYTHON)
       .orElse(sparkConf.get(PYSPARK_PYTHON))
       .orElse(sys.env.get("PYSPARK_DRIVER_PYTHON"))
@@ -51,15 +50,8 @@
     val formattedPythonFile = formatPath(pythonFile)
     val formattedPyFiles = resolvePyFiles(formatPaths(pyFiles))
 
-    // Launch a Py4J gateway server for the process to connect to; this will let it see our
-    // Java system properties and such
-    val localhost = InetAddress.getLoopbackAddress()
-    val gatewayServer = new py4j.GatewayServer.GatewayServerBuilder()
-      .authToken(secret)
-      .javaPort(0)
-      .javaAddress(localhost)
-      .callbackClient(py4j.GatewayServer.DEFAULT_PYTHON_PORT, localhost, secret)
-      .build()
+    val gatewayServer = new Py4JServer(sparkConf)
+
     val thread = new Thread(() => Utils.logUncaughtExceptions { gatewayServer.start() })
     thread.setName("py4j-gateway-init")
     thread.setDaemon(true)
@@ -86,7 +78,7 @@
     // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u:
     env.put("PYTHONUNBUFFERED", "YES") // value is needed to be set to a non-empty string
     env.put("PYSPARK_GATEWAY_PORT", "" + gatewayServer.getListeningPort)
-    env.put("PYSPARK_GATEWAY_SECRET", secret)
+    env.put("PYSPARK_GATEWAY_SECRET", gatewayServer.secret)
     // pass conf spark.pyspark.python to python process, the only way to pass info to
     // python process is through environment variable.
     sparkConf.get(PYSPARK_PYTHON).foreach(env.put("PYSPARK_PYTHON", _))
diff --git a/docs/job-scheduling.md b/docs/job-scheduling.md
index 3e70c59d89a3..eaacfa49c657 100644
--- a/docs/job-scheduling.md
+++ b/docs/job-scheduling.md
@@ -287,3 +287,21 @@ users can set the `spark.sql.thriftserver.scheduler.pool` variable:
 {% highlight SQL %}
 SET spark.sql.thriftserver.scheduler.pool=accounting;
 {% endhighlight %}
+
+## Concurrent Jobs in PySpark
+
+By default, PySpark does not synchronize PVM threads with JVM threads, and launching multiple
+jobs in multiple PVM threads does not guarantee that each job runs in its own corresponding JVM
+thread. Due to this limitation, a separate PVM thread cannot set a different job group via
+`sc.setJobGroup`, which in turn prevents cancelling that job later via `sc.cancelJobGroup`.
+
+To synchronize PVM threads with JVM threads, set the `PYSPARK_PIN_THREAD` environment variable to
+`true`. This pinned thread mode gives each PVM thread exactly one corresponding JVM thread.
+
+However, pinned thread mode currently cannot inherit the local properties from the parent thread,
+although it isolates each thread with its own local properties. To work around this, manually copy
+and set the local properties from the parent thread to the child thread when you create another
+thread in the PVM, as in the example below.
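+
+For example, when pinned thread mode is enabled, a child thread can re-apply the parent's
+properties before it submits a job. The sketch below is illustrative only; it assumes an active
+`SparkContext` named `sc` and uses the scheduler pool property purely as an example:
+
+{% highlight python %}
+import threading
+
+def run_in_child(pool):
+    # Local properties are not inherited, so set them again inside the child thread.
+    if pool is not None:
+        sc.setLocalProperty("spark.scheduler.pool", pool)
+    sc.parallelize(range(100)).count()
+
+# Read the property in the parent thread and pass it to the child explicitly.
+parent_pool = sc.getLocalProperty("spark.scheduler.pool")
+threading.Thread(target=run_in_child, args=(parent_pool,)).start()
+{% endhighlight %}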
+
+Note that `PYSPARK_PIN_THREAD` is currently experimental and not recommended for use in production.
+
diff --git a/python/pyspark/context.py b/python/pyspark/context.py
index bf96fba90b2c..e7e7bcd95a06 100644
--- a/python/pyspark/context.py
+++ b/python/pyspark/context.py
@@ -1009,14 +1009,61 @@ def setJobGroup(self, groupId, description, interruptOnCancel=False):
         in Thread.interrupt() being called on the job's executor threads. This is useful to help
         ensure that the tasks are actually stopped in a timely manner, but is off by default due
         to HDFS-1208, where HDFS may respond to Thread.interrupt() by marking nodes as dead.
-        """
+
+        .. note:: Currently, setting a group ID (set to local properties) with a thread does
+            not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
+            can be reused for multiple threads on PVM, which fails to isolate local properties
+            for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
+            `'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
+            from the parent thread although it isolates each thread on PVM and JVM with its own
+            local properties. To work around this, you should manually copy and set the local
+            properties from the parent thread to the child thread when you create another thread.
+        """
+        warnings.warn(
+            "Currently, setting a group ID (set to local properties) with a thread does "
+            "not properly work. "
+            "\n"
+            "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
+            "for multiple threads on PVM, which fails to isolate local properties for each "
+            "thread on PVM. "
+            "\n"
+            "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
+            "However, note that it cannot inherit the local properties from the parent thread "
+            "although it isolates each thread on PVM and JVM with its own local properties. "
+            "\n"
+            "To work around this, you should manually copy and set the local properties from "
+            "the parent thread to the child thread when you create another thread.",
+            UserWarning)
         self._jsc.setJobGroup(groupId, description, interruptOnCancel)
 
     def setLocalProperty(self, key, value):
         """
         Set a local property that affects jobs submitted from this thread, such as the
         Spark fair scheduler pool.
-        """
+
+        .. note:: Currently, setting a local property with a thread does
+            not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
+            can be reused for multiple threads on PVM, which fails to isolate local properties
+            for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
+            `'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
+            from the parent thread although it isolates each thread on PVM and JVM with its own
+            local properties. To work around this, you should manually copy and set the local
+            properties from the parent thread to the child thread when you create another thread.
+        """
+        warnings.warn(
+            "Currently, setting a local property with a thread does not properly work. "
+            "\n"
+            "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
+            "for multiple threads on PVM, which fails to isolate local properties for each "
+            "thread on PVM. "
+            "\n"
+            "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
+            "However, note that it cannot inherit the local properties from the parent thread "
+            "although it isolates each thread on PVM and JVM with its own local properties. "
+            "\n"
+            "To work around this, you should manually copy and set the local properties from "
+            "the parent thread to the child thread when you create another thread.",
+            UserWarning)
         self._jsc.setLocalProperty(key, value)
 
     def getLocalProperty(self, key):
@@ -1029,7 +1076,31 @@ def getLocalProperty(self, key):
     def setJobDescription(self, value):
         """
         Set a human readable description of the current job.
-        """
+
+        .. note:: Currently, setting a job description (set to local properties) with a thread does
+            not properly work. Internally threads on PVM and JVM are not synced, and JVM thread
+            can be reused for multiple threads on PVM, which fails to isolate local properties
+            for each thread on PVM. To work around this, you can set `PYSPARK_PIN_THREAD` to
+            `'true'` (see SPARK-22340). However, note that it cannot inherit the local properties
+            from the parent thread although it isolates each thread on PVM and JVM with its own
+            local properties. To work around this, you should manually copy and set the local
+            properties from the parent thread to the child thread when you create another thread.
+        """
+        warnings.warn(
+            "Currently, setting a job description (set to local properties) with a thread does "
+            "not properly work. "
+            "\n"
+            "Internally threads on PVM and JVM are not synced, and JVM thread can be reused "
+            "for multiple threads on PVM, which fails to isolate local properties for each "
+            "thread on PVM. "
+            "\n"
+            "To work around this, you can set PYSPARK_PIN_THREAD to true (see SPARK-22340). "
+            "However, note that it cannot inherit the local properties from the parent thread "
+            "although it isolates each thread on PVM and JVM with its own local properties. "
+            "\n"
+            "To work around this, you should manually copy and set the local properties from "
+            "the parent thread to the child thread when you create another thread.",
+            UserWarning)
         self._jsc.setJobDescription(value)
 
     def sparkUser(self):
diff --git a/python/pyspark/java_gateway.py b/python/pyspark/java_gateway.py
index b09bd01638d8..316a5b4d0127 100644
--- a/python/pyspark/java_gateway.py
+++ b/python/pyspark/java_gateway.py
@@ -31,6 +31,7 @@
     xrange = range
 
 from py4j.java_gateway import java_import, JavaGateway, JavaObject, GatewayParameters
+from py4j.clientserver import ClientServer, JavaParameters, PythonParameters
 from pyspark.find_spark_home import _find_spark_home
 from pyspark.serializers import read_int, write_with_length, UTF8Deserializer
 from pyspark.util import _exception_message
@@ -125,10 +126,23 @@ def killChild():
             Popen(["cmd", "/c", "taskkill", "/f", "/t", "/pid", str(proc.pid)])
         atexit.register(killChild)
 
-    # Connect to the gateway
-    gateway = JavaGateway(
-        gateway_parameters=GatewayParameters(port=gateway_port, auth_token=gateway_secret,
-                                             auto_convert=True))
+    # Connect to the gateway (or client server to pin the thread between JVM and Python)
+    if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
+        gateway = ClientServer(
+            java_parameters=JavaParameters(
+                port=gateway_port,
+                auth_token=gateway_secret,
+                auto_convert=True),
+            python_parameters=PythonParameters(
+                port=0,
+                eager_load=False))
+    else:
+        gateway = JavaGateway(
+            gateway_parameters=GatewayParameters(
+                port=gateway_port,
+                auth_token=gateway_secret,
+                auto_convert=True))
+
     # Store a reference to the Popen object for use by the caller (e.g., in reading stdout/stderr)
     gateway.proc = proc
 
diff --git a/python/pyspark/ml/tests/test_wrapper.py b/python/pyspark/ml/tests/test_wrapper.py
index 09456d8e97a4..c0747155cb72 100644
--- a/python/pyspark/ml/tests/test_wrapper.py
+++ b/python/pyspark/ml/tests/test_wrapper.py
@@ -24,6 +24,7 @@
 from pyspark.ml.wrapper import _java2py, _py2java, JavaParams, JavaWrapper
 from pyspark.testing.mllibutils import MLlibTestCase
 from pyspark.testing.mlutils import SparkSessionTestCase
+from pyspark.testing.utils import eventually
 
 
 class JavaWrapperMemoryTests(SparkSessionTestCase):
@@ -50,19 +51,27 @@ def test_java_object_gets_detached(self):
 
         model.__del__()
 
-        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
-            model._java_obj.toString()
-        self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+        def condition():
+            with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+                model._java_obj.toString()
+            self.assertIn("LinearRegressionTrainingSummary", summary._java_obj.toString())
+            return True
+
+        eventually(condition, timeout=10, catch_assertions=True)
 
         try:
             summary.__del__()
         except:
             pass
 
-        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
-            model._java_obj.toString()
-        with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
-            summary._java_obj.toString()
+        def condition():
+            with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+                model._java_obj.toString()
+            with self.assertRaisesRegexp(py4j.protocol.Py4JError, error_no_object):
+                summary._java_obj.toString()
+            return True
+
+        eventually(condition, timeout=10, catch_assertions=True)
 
 
 class WrapperTests(MLlibTestCase):
diff --git a/python/pyspark/mllib/tests/test_streaming_algorithms.py b/python/pyspark/mllib/tests/test_streaming_algorithms.py
index 6f098f458293..8056debb963b 100644
--- a/python/pyspark/mllib/tests/test_streaming_algorithms.py
+++ b/python/pyspark/mllib/tests/test_streaming_algorithms.py
@@ -28,6 +28,7 @@
 from pyspark.mllib.regression import LabeledPoint, StreamingLinearRegressionWithSGD
 from pyspark.mllib.util import LinearDataGenerator
 from pyspark.streaming import StreamingContext
+from pyspark.testing.utils import eventually
 
 
 class MLLibStreamingTestCase(unittest.TestCase):
@@ -39,44 +40,6 @@ def tearDown(self):
         self.ssc.stop(False)
         self.sc.stop()
 
-    @staticmethod
-    def _eventually(condition, timeout=30.0, catch_assertions=False):
-        """
-        Wait a given amount of time for a condition to pass, else fail with an error.
-        This is a helper utility for streaming ML tests.
-        :param condition: Function that checks for termination conditions.
-                          condition() can return:
-                           - True: Conditions met. Return without error.
-                           - other value: Conditions not met yet. Continue. Upon timeout,
-                             include last such value in error message.
-                          Note that this method may be called at any time during
-                          streaming execution (e.g., even before any results
-                          have been created).
-        :param timeout: Number of seconds to wait. Default 30 seconds.
-        :param catch_assertions: If False (default), do not catch AssertionErrors.
-                                 If True, catch AssertionErrors; continue, but save
-                                 error to throw upon timeout.
-        """
-        start_time = time()
-        lastValue = None
-        while time() - start_time < timeout:
-            if catch_assertions:
-                try:
-                    lastValue = condition()
-                except AssertionError as e:
-                    lastValue = e
-            else:
-                lastValue = condition()
-            if lastValue is True:
-                return
-            sleep(0.01)
-        if isinstance(lastValue, AssertionError):
-            raise lastValue
-        else:
-            raise AssertionError(
-                "Test failed due to timeout after %g sec, with last condition returning: %s"
-                % (timeout, lastValue))
-
 
 class StreamingKMeansTest(MLLibStreamingTestCase):
     def test_model_params(self):
@@ -111,7 +74,7 @@ def test_accuracy_for_single_center(self):
         def condition():
             self.assertEqual(stkm.latestModel().clusterWeights, [25.0])
             return True
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
         realCenters = array_sum(array(centers), axis=0)
         for i in range(5):
@@ -155,7 +118,7 @@ def condition():
             self.assertTrue(all(finalModel.centers == array(initCenters)))
             self.assertEqual(finalModel.clusterWeights, [5.0, 5.0, 5.0, 5.0])
             return True
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
     def test_predictOn_model(self):
         """Test that the model predicts correctly on toy data."""
@@ -183,7 +146,7 @@ def condition():
             self.assertEqual(result, [[0], [1], [2], [3]])
             return True
 
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
     @unittest.skip("SPARK-10086: Flaky StreamingKMeans test in PySpark")
     def test_trainOn_predictOn(self):
@@ -216,7 +179,7 @@ def condition():
             self.assertEqual(predict_results, [[0, 1, 1], [1, 0, 1]])
             return True
 
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
 
 class StreamingLogisticRegressionWithSGDTests(MLLibStreamingTestCase):
@@ -263,7 +226,7 @@ def condition():
             self.assertAlmostEqual(rel, 0.1, 1)
             return True
 
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
     def test_convergence(self):
         """
@@ -289,7 +252,7 @@ def condition():
             return True
 
         # We want all batches to finish for this test.
-        self._eventually(condition, 60.0, catch_assertions=True)
+        eventually(condition, 60.0, catch_assertions=True)
 
         t_models = array(models)
         diff = t_models[1:] - t_models[:-1]
@@ -322,7 +285,7 @@ def condition():
             self.assertEqual(len(true_predicted), len(input_batches))
             return True
 
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
         # Test that the accuracy error is no more than 0.4 on each batch.
         for batch in true_predicted:
@@ -364,7 +327,7 @@ def condition():
                 return True
             return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
 
-        self._eventually(condition, timeout=60.0)
+        eventually(condition, timeout=60.0)
 
 
 class StreamingLinearRegressionWithTests(MLLibStreamingTestCase):
@@ -400,7 +363,7 @@ def condition():
             self.assertAlmostEqual(slr.latestModel().intercept, 0.0, 1)
             return True
 
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
     def test_parameter_convergence(self):
         """Test that the model parameters improve with streaming data."""
@@ -426,7 +389,7 @@ def condition():
             return True
 
         # We want all batches to finish for this test.
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
         w = array(model_weights)
         diff = w[1:] - w[:-1]
@@ -459,7 +422,7 @@ def condition():
             return True
 
         # We want all batches to finish for this test.
-        self._eventually(condition, catch_assertions=True)
+        eventually(condition, catch_assertions=True)
 
         # Test that mean absolute error on each batch is less than 0.1
         for batch in samples:
@@ -500,7 +463,7 @@ def condition():
                 return True
             return "Latest errors: " + ", ".join(map(lambda x: str(x), errors))
 
-        self._eventually(condition)
+        eventually(condition)
 
 
 if __name__ == "__main__":
diff --git a/python/pyspark/testing/utils.py b/python/pyspark/testing/utils.py
index 2b42b898f9ed..cda902b6f44d 100644
--- a/python/pyspark/testing/utils.py
+++ b/python/pyspark/testing/utils.py
@@ -19,6 +19,7 @@
 import struct
 import sys
 import unittest
+from time import time, sleep
 
 from pyspark import SparkContext, SparkConf
 
@@ -50,6 +51,45 @@ def write_int(i):
     return struct.pack("!i", i)
 
 
+def eventually(condition, timeout=30.0, catch_assertions=False):
+    """
+    Wait a given amount of time for a condition to pass, else fail with an error.
+    This is a helper utility for PySpark tests.
+
+    :param condition: Function that checks for termination conditions.
+                      condition() can return:
+                       - True: Conditions met. Return without error.
+                       - other value: Conditions not met yet. Continue. Upon timeout,
+                         include last such value in error message.
+                      Note that this method may be called at any time during
+                      streaming execution (e.g., even before any results
+                      have been created).
+    :param timeout: Number of seconds to wait. Default 30 seconds.
+    :param catch_assertions: If False (default), do not catch AssertionErrors.
+                             If True, catch AssertionErrors; continue, but save
+                             error to throw upon timeout.
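+
+    Example (illustrative only; ``results`` stands for whatever state the test accumulates)::
+
+        def condition():
+            assert len(results) == 3
+            return True
+
+        eventually(condition, timeout=30.0, catch_assertions=True)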
+    """
+    start_time = time()
+    lastValue = None
+    while time() - start_time < timeout:
+        if catch_assertions:
+            try:
+                lastValue = condition()
+            except AssertionError as e:
+                lastValue = e
+        else:
+            lastValue = condition()
+        if lastValue is True:
+            return
+        sleep(0.01)
+    if isinstance(lastValue, AssertionError):
+        raise lastValue
+    else:
+        raise AssertionError(
+            "Test failed due to timeout after %g sec, with last condition returning: %s"
+            % (timeout, lastValue))
+
+
 class QuietTest(object):
     def __init__(self, sc):
         self.log4j = sc._jvm.org.apache.log4j
diff --git a/python/pyspark/tests/test_context.py b/python/pyspark/tests/test_context.py
index 3f3150b0bd4e..c7f435a58221 100644
--- a/python/pyspark/tests/test_context.py
+++ b/python/pyspark/tests/test_context.py
@@ -214,6 +214,10 @@ def test_progress_api(self):
         rdd = sc.parallelize(range(10)).map(lambda x: time.sleep(100))
 
         def run():
+            # When the thread is pinned, the job group should be set for each thread for now.
+            # Local properties do not seem to be inherited as they are on the Scala side.
+            if os.environ.get("PYSPARK_PIN_THREAD", "false").lower() == "true":
+                sc.setJobGroup('test_progress_api', '', True)
             try:
                 rdd.count()
             except Exception:
diff --git a/python/pyspark/tests/test_pin_thread.py b/python/pyspark/tests/test_pin_thread.py
new file mode 100644
index 000000000000..657d129fe63b
--- /dev/null
+++ b/python/pyspark/tests/test_pin_thread.py
@@ -0,0 +1,156 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+import os
+import time
+import random
+import threading
+import unittest
+
+from pyspark import SparkContext, SparkConf
+
+
+class PinThreadTests(unittest.TestCase):
+    # These tests are in a separate class because they use the
+    # 'PYSPARK_PIN_THREAD' environment variable to test the pinned thread feature.
+
+    @classmethod
+    def setUpClass(cls):
+        cls.old_pin_thread = os.environ.get("PYSPARK_PIN_THREAD")
+        os.environ["PYSPARK_PIN_THREAD"] = "true"
+        cls.sc = SparkContext('local[4]', cls.__name__, conf=SparkConf())
+
+    @classmethod
+    def tearDownClass(cls):
+        cls.sc.stop()
+        if cls.old_pin_thread is not None:
+            os.environ["PYSPARK_PIN_THREAD"] = cls.old_pin_thread
+        else:
+            del os.environ["PYSPARK_PIN_THREAD"]
+
+    def test_pinned_thread(self):
+        threads = []
+        exceptions = []
+        property_name = "test_property_%s" % PinThreadTests.__name__
+        jvm_thread_ids = []
+
+        for i in range(10):
+            def test_local_property():
+                jvm_thread_id = self.sc._jvm.java.lang.Thread.currentThread().getId()
+                jvm_thread_ids.append(jvm_thread_id)
+
+                # If a property is set in this thread, reading it later within the same
+                # thread should return the same property.
+                self.sc.setLocalProperty(property_name, str(i))
+
+                # 5 threads, 1 second sleep. 5 threads without a sleep.
+                time.sleep(i % 2)
+
+                try:
+                    assert self.sc.getLocalProperty(property_name) == str(i)
+
+                    # Each command might create a thread in multi-threading mode in Py4J.
+                    # This assert makes sure that the created thread is being reused.
+                    assert jvm_thread_id == self.sc._jvm.java.lang.Thread.currentThread().getId()
+                except Exception as e:
+                    exceptions.append(e)
+            threads.append(threading.Thread(target=test_local_property))
+
+        for t in threads:
+            t.start()
+
+        for t in threads:
+            t.join()
+
+        for e in exceptions:
+            raise e
+
+        # 10 JVM threads should have been created because there are 10 Python threads.
+        assert len(set(jvm_thread_ids)) == 10
+
+    def test_multiple_group_jobs(self):
+        # SPARK-22340: Add a mode to pin the Python thread into the JVM's.
+
+        group_a = "job_ids_to_cancel"
+        group_b = "job_ids_to_run"
+
+        threads = []
+        thread_ids = range(4)
+        thread_ids_to_cancel = [i for i in thread_ids if i % 2 == 0]
+        thread_ids_to_run = [i for i in thread_ids if i % 2 != 0]
+
+        # A list which records whether each job was cancelled.
+        # The index of the list is the index of the thread the job runs in.
+        is_job_cancelled = [False for _ in thread_ids]
+
+        def run_job(job_group, index):
+            """
+            Executes a job with the group ``job_group``. Each job sleeps for 15 seconds
+            and then exits.
+            """
+            try:
+                self.sc.setJobGroup(job_group, "test rdd collect with setting job group")
+                self.sc.parallelize([15]).map(lambda x: time.sleep(x)).collect()
+                is_job_cancelled[index] = False
+            except Exception:
+                # Assume that exception means job cancellation.
+                is_job_cancelled[index] = True
+
+        # Test that the job succeeds when it is not cancelled.
+        run_job(group_a, 0)
+        self.assertFalse(is_job_cancelled[0])
+
+        # Run jobs
+        for i in thread_ids_to_cancel:
+            t = threading.Thread(target=run_job, args=(group_a, i))
+            t.start()
+            threads.append(t)
+
+        for i in thread_ids_to_run:
+            t = threading.Thread(target=run_job, args=(group_b, i))
+            t.start()
+            threads.append(t)
+
+        # Wait to make sure all jobs have started executing.
+        time.sleep(3)
+        # And then, cancel one job group.
+        self.sc.cancelJobGroup(group_a)
+
+        # Wait until all threads launching jobs are finished.
+        for t in threads:
+            t.join()
+
+        for i in thread_ids_to_cancel:
+            self.assertTrue(
+                is_job_cancelled[i],
+                "Thread {i}: Job in group A was not cancelled.".format(i=i))
+
+        for i in thread_ids_to_run:
+            self.assertFalse(
+                is_job_cancelled[i],
+                "Thread {i}: Job in group B did not succeed.".format(i=i))
+
+
+if __name__ == "__main__":
+    import unittest
+    from pyspark.tests.test_pin_thread import *
+
+    try:
+        import xmlrunner
+        testRunner = xmlrunner.XMLTestRunner(output='target/test-reports', verbosity=2)
+    except ImportError:
+        testRunner = None
+    unittest.main(testRunner=testRunner, verbosity=2)