[SPARK-23961][SPARK-27548][PYTHON] Fix error when toLocalIterator goes out of scope and properly raise errors from worker #24070
Changes from all commits: 899ad8d, 8c309c5, 866d585, 9ad3a77, 3415ff1, 57d251c, 600a906, 0a796d7, 7847a14, a1f811a, 29b8ab6, 4f842dc
PythonRDD.scala (Scala side):

```diff
@@ -163,8 +163,63 @@ private[spark] object PythonRDD extends Logging {
     serveIterator(rdd.collect().iterator, s"serve RDD ${rdd.id}")
   }
 
+  /**
+   * A helper function to create a local RDD iterator and serve it via socket. Partitions are
+   * collected as separate jobs, by order of index. Partition data is first requested by a
+   * non-zero integer to start a collection job. The response is prefaced by an integer with 1
+   * meaning partition data will be served, 0 meaning the local iterator has been consumed,
+   * and -1 meaning an error occurred during collection. This function is used by
+   * pyspark.rdd._local_iterator_from_socket().
+   *
+   * @return 2-tuple (as a Java array) with the port number of a local socket which serves the
+   *         data collected from these jobs, and the secret for authentication.
+   */
   def toLocalIteratorAndServe[T](rdd: RDD[T]): Array[Any] = {
-    serveIterator(rdd.toLocalIterator, s"serve toLocalIterator")
+    val (port, secret) = SocketAuthServer.setupOneConnectionServer(
+      authHelper, "serve toLocalIterator") { s =>
+      val out = new DataOutputStream(s.getOutputStream)
+      val in = new DataInputStream(s.getInputStream)
+      Utils.tryWithSafeFinally {
+
+        // Collects a partition on each iteration
+        val collectPartitionIter = rdd.partitions.indices.iterator.map { i =>
+          rdd.sparkContext.runJob(rdd, (iter: Iterator[Any]) => iter.toArray, Seq(i)).head
+        }
+
+        // Read request for data and send next partition if nonzero
+        var complete = false
+        while (!complete && in.readInt() != 0) {
+          if (collectPartitionIter.hasNext) {
+            try {
+              // Attempt to collect the next partition
+              val partitionArray = collectPartitionIter.next()
+
+              // Send response there is a partition to read
+              out.writeInt(1)
+
+              // Write the next object and signal end of data for this iteration
+              writeIteratorToStream(partitionArray.toIterator, out)
+              out.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+              out.flush()
+            } catch {
+              case e: SparkException =>
+                // Send response that an error occurred followed by error message
+                out.writeInt(-1)
+                writeUTF(e.getMessage, out)
+                complete = true
+            }
+          } else {
+            // Send response there are no more partitions to read and close
+            out.writeInt(0)
+            complete = true
+          }
+        }
+      } {
+        out.close()
+        in.close()
+      }
+    }
+    Array(port, secret)
+  }
 
   def readRDDFromFile(
```
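To make the request/response protocol described in the new Scaladoc concrete, here is a rough, self-contained sketch of one round trip. It uses an in-memory buffer in place of the real socket and re-implements the 4-byte integer framing locally (the names `write_int`/`read_int` only mirror the helpers that `pyspark.rdd` imports from `pyspark.serializers`); the actual partition payload is elided.

```python
import io
import struct


def write_int(value, stream):
    # 4-byte big-endian framing, matching the JVM's DataOutputStream.writeInt
    stream.write(struct.pack("!i", value))


def read_int(stream):
    return struct.unpack("!i", stream.read(4))[0]


# Driver -> JVM: any non-zero integer requests that the next partition be collected.
to_jvm = io.BytesIO()
write_int(1, to_jvm)

# JVM -> driver: a status integer prefaces every response.
#   1  -> a serialized partition follows
#   0  -> the local iterator is fully consumed
#  -1  -> a SparkException occurred; a UTF-8 error message follows
to_driver = io.BytesIO()
write_int(1, to_driver)

to_jvm.seek(0)
to_driver.seek(0)
print(read_int(to_jvm), read_int(to_driver))  # 1 1
```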
pyspark/rdd.py (Python side):

```diff
@@ -39,9 +39,9 @@
     from itertools import imap as map, ifilter as filter
 
 from pyspark.java_gateway import local_connect_and_auth
-from pyspark.serializers import NoOpSerializer, CartesianDeserializer, \
-    BatchedSerializer, CloudPickleSerializer, PairDeserializer, \
-    PickleSerializer, pack_long, AutoBatchedSerializer
+from pyspark.serializers import AutoBatchedSerializer, BatchedSerializer, NoOpSerializer, \
+    CartesianDeserializer, CloudPickleSerializer, PairDeserializer, PickleSerializer, \
+    UTF8Deserializer, pack_long, read_int, write_int
 from pyspark.join import python_join, python_left_outer_join, \
     python_right_outer_join, python_full_outer_join, python_cogroup
 from pyspark.statcounter import StatCounter
@@ -138,15 +138,69 @@ def _parse_memory(s):
     return int(float(s[:-1]) * units[s[-1].lower()])
 
 
-def _load_from_socket(sock_info, serializer):
+def _create_local_socket(sock_info):
     (sockfile, sock) = local_connect_and_auth(*sock_info)
-    # The RDD materialization time is unpredicable, if we set a timeout for socket reading
+    # The RDD materialization time is unpredictable, if we set a timeout for socket reading
     # operation, it will very possibly fail. See SPARK-18281.
     sock.settimeout(None)
+    return sockfile
+
+
+def _load_from_socket(sock_info, serializer):
+    sockfile = _create_local_socket(sock_info)
     # The socket will be automatically closed when garbage-collected.
     return serializer.load_stream(sockfile)
 
 
+def _local_iterator_from_socket(sock_info, serializer):
+
+    class PyLocalIterable(object):
+        """ Create a synchronous local iterable over a socket """
+
+        def __init__(self, _sock_info, _serializer):
+            self._sockfile = _create_local_socket(_sock_info)
+            self._serializer = _serializer
+            self._read_iter = iter([])  # Initialize as empty iterator
+            self._read_status = 1
+
+        def __iter__(self):
+            while self._read_status == 1:
+                # Request next partition data from Java
+                write_int(1, self._sockfile)
+                self._sockfile.flush()
+
+                # If response is 1 then there is a partition to read, if 0 then fully consumed
+                self._read_status = read_int(self._sockfile)
+                if self._read_status == 1:
+                    # Load the partition data as a stream and read each item
+                    self._read_iter = self._serializer.load_stream(self._sockfile)
+                    for item in self._read_iter:
+                        yield item
+
+                # An error occurred, read error message and raise it
+                elif self._read_status == -1:
+                    error_msg = UTF8Deserializer().loads(self._sockfile)
+                    raise RuntimeError("An error occurred while reading the next element from "
+                                       "toLocalIterator: {}".format(error_msg))
+
+        def __del__(self):
+            # If local iterator is not fully consumed,
+            if self._read_status == 1:
+                try:
+                    # Finish consuming partition data stream
+                    for _ in self._read_iter:
+                        pass
+                    # Tell Java to stop sending data and close connection
+                    write_int(0, self._sockfile)
+                    self._sockfile.flush()
+                except Exception:
+                    # Ignore any errors, socket is automatically closed when garbage-collected
+                    pass
+
+    return iter(PyLocalIterable(sock_info, serializer))
+
+
 def ignore_unicode_prefix(f):
     """
     Ignore the 'u' prefix of string in doc tests, to make it works
@@ -2386,7 +2440,7 @@ def toLocalIterator(self):
         """
         with SCCallSiteSync(self.context) as css:
             sock_info = self.ctx._jvm.PythonRDD.toLocalIteratorAndServe(self._jrdd.rdd())
-        return _load_from_socket(sock_info, self._jrdd_deserializer)
+        return _local_iterator_from_socket(sock_info, self._jrdd_deserializer)
 
     def barrier(self):
         """
```
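As a quick illustration of what the new `_local_iterator_from_socket` path means at the API level, the sketch below only triggers collection jobs for the partitions that iteration actually reaches, one job per partition in index order. The session setup and sizes here are illustrative, not taken from the PR; this is a hedged usage sketch, assuming a local PySpark build with this patch.

```python
from pyspark.sql import SparkSession

spark = SparkSession.builder.master("local[2]").appName("toLocalIterator-sketch").getOrCreate()
sc = spark.sparkContext

# Four partitions, but we only iterate far enough to need the first two,
# so only two collection jobs should run.
rdd = sc.range(12, numSlices=4)
it = rdd.toLocalIterator()
head = [next(it) for _ in range(4)]
print(head)  # [0, 1, 2, 3]

it = None     # drop the iterator; cleanup drains the in-flight partition and sends the stop command
spark.stop()
```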
PySpark SQL DataFrame tests:

```diff
@@ -677,6 +677,34 @@ def test_repr_behaviors(self):
             self.assertEquals(None, df._repr_html_())
             self.assertEquals(expected, df.__repr__())
 
+    def test_to_local_iterator(self):
+        df = self.spark.range(8, numPartitions=4)
+        expected = df.collect()
+        it = df.toLocalIterator()
+        self.assertEqual(expected, list(it))
+
+        # Test DataFrame with empty partition
+        df = self.spark.range(3, numPartitions=4)
+        it = df.toLocalIterator()
+        expected = df.collect()
+        self.assertEqual(expected, list(it))
+
+    def test_to_local_iterator_not_fully_consumed(self):
+        # SPARK-23961: toLocalIterator throws exception when not fully consumed
+        # Create a DataFrame large enough so that write to socket will eventually block
+        df = self.spark.range(1 << 20, numPartitions=2)
+        it = df.toLocalIterator()
+        self.assertEqual(df.take(1)[0], next(it))
+        with QuietTest(self.sc):
+            it = None  # remove iterator from scope, socket is closed when cleaned up
+        # Make sure normal df operations still work
+        result = []
+        for i, row in enumerate(df.toLocalIterator()):
+            result.append(row)
+            if i == 7:
+                break
+        self.assertEqual(df.take(8), result)
+
 
 class QueryExecutionListenerTests(unittest.TestCase, SQLTestUtils):
     # These tests are separate because it uses 'spark.sql.queryExecutionListeners' which is
```
PySpark RDD tests:

```diff
@@ -60,15 +60,12 @@ def test_sum(self):
         self.assertEqual(6, self.sc.parallelize([1, 2, 3]).sum())
 
     def test_to_localiterator(self):
-        from time import sleep
         rdd = self.sc.parallelize([1, 2, 3])
         it = rdd.toLocalIterator()
-        sleep(5)
         self.assertEqual([1, 2, 3], sorted(it))
 
         rdd2 = rdd.repartition(1000)
        it2 = rdd2.toLocalIterator()
-        sleep(5)
         self.assertEqual([1, 2, 3], sorted(it2))
 
     def test_save_as_textfile_with_unicode(self):
@@ -736,6 +733,34 @@ def test_overwritten_global_func(self):
         global_func = lambda: "Yeah"
         self.assertEqual(self.sc.parallelize([1]).map(lambda _: global_func()).first(), "Yeah")
 
+    def test_to_local_iterator_failure(self):
+        # SPARK-27548 toLocalIterator task failure not propagated to Python driver
+
+        def fail(_):
+            raise RuntimeError("local iterator error")
+
+        rdd = self.sc.range(10).map(fail)
+
+        with self.assertRaisesRegexp(Exception, "local iterator error"):
+            for _ in rdd.toLocalIterator():
+                pass
+
+    def test_to_local_iterator_collects_single_partition(self):
+        # Test that partitions are not computed until requested by iteration
+
+        def fail_last(x):
+            if x == 9:
+                raise RuntimeError("This should not be hit")
+            return x
+
+        rdd = self.sc.range(12, numSlices=4).map(fail_last)
+        it = rdd.toLocalIterator()
+
+        # Only consume first 4 elements from partitions 1 and 2, this should not collect the last
+        # partition which would trigger the error
+        for i in range(4):
+            self.assertEqual(i, next(it))
+
 
 if __name__ == "__main__":
     import unittest
```
Review discussion:

Once the local iterator goes out of scope on the Python side, will the remaining jobs still be triggered, given that the Scala side can no longer write into the closed connection?

No, the remaining jobs are not triggered. The Python iterator finishes consuming the data from the current job, then sends a command telling the Scala iterator to stop.
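A rough end-to-end illustration of this answer, modeled on `test_to_local_iterator_not_fully_consumed` above; the session setup is illustrative and this is a sketch of the intended behavior, not code from the PR.

```python
from pyspark.sql import SparkSession

sc = SparkSession.builder.master("local[2]").getOrCreate().sparkContext

# Big enough that the JVM side would block on the socket write if it ran ahead.
it = sc.range(1 << 20, numSlices=2).toLocalIterator()
first = next(it)

# Dropping the last reference triggers PyLocalIterable.__del__: it drains the partition
# already in flight, then writes 0 to tell the JVM to stop collecting further partitions.
it = None

# The SparkContext stays healthy; ordinary jobs still run afterwards.
print(sc.range(5).collect())  # [0, 1, 2, 3, 4]
```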
What about the previous behavior? Would it have triggered them? It looks like `toLocalIterator` won't trigger a job if we don't iterate over the data of a partition.

The previous behavior was that the Scala local iterator would advance as long as the `write` calls to the socket were not blocked. So when Python read a batch (auto-batched elements) from the current partition, that would unblock the Scala `write` call, which could start a job to collect the next partition.

Once the local iterator on the Python side goes out of scope before being fully consumed, will that block the `write` call on the Scala side? It seems to me that it will, so we shouldn't see unneeded jobs triggered after that, should we?

The previous behavior was that when the iterator went out of scope, the socket was eventually closed. This caused an error on the Scala side and terminated the writing thread, so no more jobs were triggered, but the user saw this error.