[SPARK-3073] [PySpark] use external sort in sortBy() and sortByKey() #1978
```diff
@@ -31,6 +31,7 @@ sorttable.js
 .*data
 .*log
 cloudpickle.py
+heapq3.py
 join.py
 SparkExprTyper.scala
 SparkILoop.scala
```
Large diffs are not rendered by default (the new pyspark.heapq3 module).
python/pyspark/shuffle.py:

```diff
@@ -21,7 +21,10 @@
 import shutil
 import warnings
 import gc
+import itertools
+import operator
 
+import pyspark.heapq3 as heapq
 from pyspark.serializers import BatchedSerializer, PickleSerializer
 
 try:
```
```diff
@@ -54,6 +57,13 @@ def get_used_memory():
         return 0
 
 
+def _get_local_dirs(sub):
+    """ Get all the directories """
+    path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
+    dirs = path.split(",")
+    return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]
+
+
 class Aggregator(object):
 
     """
```
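To illustrate what the helper produces, here is a quick standalone check; the `SPARK_LOCAL_DIR` value and the `"sort"` subdirectory below are example inputs only:

```python
import os

os.environ["SPARK_LOCAL_DIR"] = "/disk1,/disk2"  # example: two local disks

# mirrors _get_local_dirs("sort"): split the comma-separated setting and
# append a per-process subdirectory on each disk
path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
dirs = [os.path.join(d, "python", str(os.getpid()), "sort")
        for d in path.split(",")]
print(dirs)  # e.g. ['/disk1/python/1234/sort', '/disk2/python/1234/sort']
```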
```diff
@@ -196,7 +206,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
         # default serializer is only used for tests
         self.serializer = serializer or \
             BatchedSerializer(PickleSerializer(), 1024)
-        self.localdirs = localdirs or self._get_dirs()
+        self.localdirs = localdirs or _get_local_dirs(str(id(self)))
         # number of partitions when spill data into disks
         self.partitions = partitions
         # check the memory after # of items merged
```
```diff
@@ -212,13 +222,6 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
         # randomize the hash of key, id(o) is the address of o (aligned by 8)
         self._seed = id(self) + 7
 
-    def _get_dirs(self):
-        """ Get all the directories """
-        path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
-        dirs = path.split(",")
-        return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
-                for d in dirs]
-
     def _get_spill_dir(self, n):
         """ Choose one directory for spill by number n """
         return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
```
```diff
@@ -434,6 +437,63 @@ def _recursive_merged_items(self, start):
                 os.remove(os.path.join(path, str(i)))
 
 
+class ExternalSorter(object):
+    """
+    ExternalSorter will divide the elements into chunks, sort them in
+    memory and dump them into disks, finally merge them back.
+
+    The spilling will only happen when the used memory goes above
+    the limit.
+    """
+    def __init__(self, memory_limit, serializer=None):
+        self.memory_limit = memory_limit
+        self.local_dirs = _get_local_dirs("sort")
+        self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)
+
+    def _get_path(self, n):
```
Reviewer thread on `_get_path`:

Contributor: Because there will be multiple Python worker processes running on the same node, if they all need to spill, it looks like they'll use the same directories in the same order here. Can you instead start each one at a random ID and then increment it so that it cycles through? I'm not sure whether this can also affect the external hashing code, but if so, it would be good to fix that too (as a separate JIRA).

Contributor: Basically I'm worried that everyone writes to disk1 first, then everyone writes to disk2, etc., and we only use one disk at a time.

Contributor (Author): Good catch; shuffling the directories randomly at the beginning might be better. PS: could we have a configurable policy for choosing local disks, such as using the first one as much as possible? That would be useful when one of the local disks is an SSD.

Contributor: Yeah, good question. We don't have that yet, but in the future we'll have support for multiple local storage levels.
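As a minimal standalone sketch of the suggestion above (randomize the per-process directory order so concurrent workers don't all spill to the same disk first); `_choose_spill_dirs` is a hypothetical name and not code from this patch:

```python
import os
import random


def _choose_spill_dirs(sub):
    """Hypothetical variant of _get_local_dirs(): shuffle the directory list
    once per process so concurrent Python workers start on different disks."""
    path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
    dirs = [os.path.join(d, "python", str(os.getpid()), sub)
            for d in path.split(",")]
    random.shuffle(dirs)  # per-process random order; cycling still uses every disk
    return dirs
```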
```diff
+        """ Choose one directory for spill by number n """
+        d = self.local_dirs[n % len(self.local_dirs)]
+        if not os.path.exists(d):
+            os.makedirs(d)
+        return os.path.join(d, str(n))
+
+    def sorted(self, iterator, key=None, reverse=False):
+        """
+        Sort the elements in iterator, do external sort when the memory
+        goes above the limit.
+        """
+        batch = 10
+        chunks, current_chunk = [], []
+        iterator = iter(iterator)
+        while True:
+            # pick elements in batch
+            chunk = list(itertools.islice(iterator, batch))
+            current_chunk.extend(chunk)
+            if len(chunk) < batch:
+                break
+
+            if get_used_memory() > self.memory_limit:
+                # sort them inplace will save memory
+                current_chunk.sort(key=key, reverse=reverse)
+                path = self._get_path(len(chunks))
+                with open(path, 'w') as f:
+                    self.serializer.dump_stream(current_chunk, f)
+                chunks.append(self.serializer.load_stream(open(path)))
+                current_chunk = []
+
+            elif not chunks:
+                batch = min(batch * 2, 10000)
+
+        current_chunk.sort(key=key, reverse=reverse)
+        if not chunks:
+            return current_chunk
+
+        if current_chunk:
+            chunks.append(iter(current_chunk))
+
+        return heapq.merge(chunks, key=key, reverse=reverse)
+
+
 if __name__ == "__main__":
     import doctest
     doctest.testmod()
```
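For orientation, a minimal sketch of how the new sorter can be driven directly; the memory limit and data below are arbitrary example values (per the PR title, sortBy() and sortByKey() are the real callers, and that wiring is not shown in this view):

```python
from pyspark.shuffle import ExternalSorter

# Example values only: a 512 MB cap (memory_limit is in MB, the same unit as
# get_used_memory()) and a synthetic stream of (key, value) pairs.
sorter = ExternalSorter(memory_limit=512)
pairs = ((i % 1000, i) for i in range(10 ** 6))

# sorted() returns a plain list if nothing spilled, or a lazily merged
# iterator over the spilled chunks otherwise; either way it is iterable.
for kv in sorter.sorted(pairs, key=lambda kv: kv[0]):
    pass  # consume the globally sorted stream
```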
python/pyspark/tests.py:

```diff
@@ -30,6 +30,7 @@
 import tempfile
 import time
 import zipfile
+import random
 
 if sys.version_info[:2] <= (2, 6):
     import unittest2 as unittest
```
```diff
@@ -40,7 +41,7 @@
 from pyspark.context import SparkContext
 from pyspark.files import SparkFiles
 from pyspark.serializers import read_int
-from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
+from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter
 
 _have_scipy = False
 _have_numpy = False
```
```diff
@@ -117,6 +118,28 @@ def test_huge_dataset(self):
         m._cleanup()
 
 
+class TestSorter(unittest.TestCase):
+    def test_in_memory_sort(self):
+        l = range(1024)
+        random.shuffle(l)
+        sorter = ExternalSorter(1024)
+        self.assertEquals(sorted(l), sorter.sorted(l))
+        self.assertEquals(sorted(l, reverse=True), sorter.sorted(l, reverse=True))
+        self.assertEquals(sorted(l, key=lambda x: -x), sorter.sorted(l, key=lambda x: -x))
+        self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
+                          sorter.sorted(l, key=lambda x: -x, reverse=True))
+
+    def test_external_sort(self):
+        l = range(100)
+        random.shuffle(l)
+        sorter = ExternalSorter(1)
+        self.assertEquals(sorted(l), list(sorter.sorted(l)))
+        self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
+        self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
+        self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
+                          list(sorter.sorted(l, key=lambda x: -x, reverse=True)))
+
+
 class SerializationTestCase(unittest.TestCase):
 
     def test_namedtuple(self):
```

Reviewer comment on the new tests:

Contributor: Would be good to add a test that calls sortByKey on a large dataset here too, in addition to just testing the ExternalSorter separately.
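A hedged sketch of what such a test might look like (not part of this commit); it assumes the PySparkTestCase base class used elsewhere in tests.py, which provides a SparkContext as self.sc, plus the random import added above. The class and method names are hypothetical:

```python
class TestSortByKeyOnLargeData(PySparkTestCase):
    """Hypothetical follow-up test suggested in the review above."""

    def test_sort_by_key_large(self):
        # sortByKey on a larger dataset than the unit tests above; whether the
        # workers actually spill depends on the configured memory limit
        pairs = [(random.randint(0, 1 << 20), i) for i in range(100000)]
        rdd = self.sc.parallelize(pairs, 10)
        keys = rdd.sortByKey().keys().collect()
        self.assertEquals(keys, sorted(keys))
```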
```diff
@@ -15,4 +15,4 @@
 
 [pep8]
 max-line-length=100
-exclude=cloudpickle.py
+exclude=cloudpickle.py,heapq3.py
```
Reviewer comment on the `SPARK_LOCAL_DIR` lookup:

As of my new PR, this will need to be changed to "SPARK_LOCAL_DIRS" (plural).
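A minimal sketch of that follow-up, assuming the helper keeps its current shape; the fallback chain here is an illustration, not the actual follow-up patch:

```python
import os


def _get_local_dirs(sub):
    """ Get all the directories, preferring the plural SPARK_LOCAL_DIRS """
    path = os.environ.get("SPARK_LOCAL_DIRS",
                          os.environ.get("SPARK_LOCAL_DIR", "/tmp"))
    return [os.path.join(d, "python", str(os.getpid()), sub)
            for d in path.split(",")]
```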