1 change: 1 addition & 0 deletions .rat-excludes
@@ -31,6 +31,7 @@ sorttable.js
.*data
.*log
cloudpickle.py
heapq3.py
join.py
SparkExprTyper.scala
SparkILoop.scala
890 changes: 890 additions & 0 deletions python/pyspark/heapq3.py

Large diffs are not rendered by default.

23 changes: 13 additions & 10 deletions python/pyspark/rdd.py
@@ -44,7 +44,7 @@
from pyspark.storagelevel import StorageLevel
from pyspark.resultiterable import ResultIterable
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, \
get_used_memory
get_used_memory, ExternalSorter

from py4j.java_collections import ListConverter, MapConverter

@@ -587,14 +587,19 @@ def sortByKey(self, ascending=True, numPartitions=None, keyfunc=lambda x: x):
if numPartitions is None:
numPartitions = self._defaultReducePartitions()

spill = (self.ctx._conf.get("spark.shuffle.spill", 'True').lower() == 'true')
memory = _parse_memory(self.ctx._conf.get("spark.python.worker.memory", "512m"))
serializer = self._jrdd_deserializer

def sortPartition(iterator):
if spill:
sorted = ExternalSorter(memory * 0.9, serializer).sorted
return sorted(iterator, key=lambda (k, v): keyfunc(k), reverse=(not ascending))

if numPartitions == 1:
if self.getNumPartitions() > 1:
self = self.coalesce(1)

def sort(iterator):
return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

return self.mapPartitions(sort)
return self.mapPartitions(sortPartition)

# first compute the boundary of each part via sampling: we want to partition
# the key-space into bins such that the bins have roughly the same
@@ -617,10 +622,8 @@ def rangePartitionFunc(k):
else:
return numPartitions - 1 - p

def mapFunc(iterator):
return sorted(iterator, reverse=(not ascending), key=lambda (k, v): keyfunc(k))

return self.partitionBy(numPartitions, rangePartitionFunc).mapPartitions(mapFunc, True)
return (self.partitionBy(numPartitions, rangePartitionFunc)
.mapPartitions(sortPartition, True))

def sortBy(self, keyfunc, ascending=True, numPartitions=None):
"""
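
The sortPartition path above is driven by two existing configuration keys that appear in the diff, spark.shuffle.spill and spark.python.worker.memory. A minimal, illustrative way to set them from an application (values are examples only, not defaults mandated by this PR):

from pyspark import SparkConf, SparkContext

conf = (SparkConf()
        .set("spark.shuffle.spill", "true")            # allow sortByKey to spill to disk
        .set("spark.python.worker.memory", "512m"))    # soft memory limit handed to ExternalSorter
sc = SparkContext(conf=conf)

rdd = sc.parallelize([(3, "c"), (1, "a"), (2, "b")])
print(rdd.sortByKey().collect())   # [(1, 'a'), (2, 'b'), (3, 'c')]
sc.stop()
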
76 changes: 68 additions & 8 deletions python/pyspark/shuffle.py
@@ -21,7 +21,10 @@
import shutil
import warnings
import gc
import itertools
import operator

import pyspark.heapq3 as heapq
from pyspark.serializers import BatchedSerializer, PickleSerializer

try:
@@ -54,6 +57,13 @@ def get_used_memory():
return 0


def _get_local_dirs(sub):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")

Contributor:
As of my new PR, this will need to be changed to "SPARK_LOCAL_DIRS" (plural).
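
A sketch of what that lookup could become once the plural variable exists; the fallback chain below is an assumption for illustration (it relies on the module-level import os already in this file), not code from either PR:

def _get_local_dirs(sub):
    """ Get all the local directories for spilling (SPARK_LOCAL_DIRS preferred) """
    path = os.environ.get("SPARK_LOCAL_DIRS",
                          os.environ.get("SPARK_LOCAL_DIR", "/tmp"))
    dirs = path.split(",")
    return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]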

dirs = path.split(",")
return [os.path.join(d, "python", str(os.getpid()), sub) for d in dirs]


class Aggregator(object):

"""
@@ -196,7 +206,7 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
# default serializer is only used for tests
self.serializer = serializer or \
BatchedSerializer(PickleSerializer(), 1024)
self.localdirs = localdirs or self._get_dirs()
self.localdirs = localdirs or _get_local_dirs(str(id(self)))
# number of partitions when spill data into disks
self.partitions = partitions
# check the memory after # of items merged
@@ -212,13 +222,6 @@ def __init__(self, aggregator, memory_limit=512, serializer=None,
# randomize the hash of key, id(o) is the address of o (aligned by 8)
self._seed = id(self) + 7

def _get_dirs(self):
""" Get all the directories """
path = os.environ.get("SPARK_LOCAL_DIR", "/tmp")
dirs = path.split(",")
return [os.path.join(d, "python", str(os.getpid()), str(id(self)))
for d in dirs]

def _get_spill_dir(self, n):
""" Choose one directory for spill by number n """
return os.path.join(self.localdirs[n % len(self.localdirs)], str(n))
@@ -434,6 +437,63 @@ def _recursive_merged_items(self, start):
os.remove(os.path.join(path, str(i)))


class ExternalSorter(object):
"""
ExternalSorter will divide the elements into chunks, sort them in
memory, dump them to disk, and finally merge them back.

The spilling will only happen when the used memory goes above
the limit.
"""
def __init__(self, memory_limit, serializer=None):
self.memory_limit = memory_limit
self.local_dirs = _get_local_dirs("sort")
self.serializer = serializer or BatchedSerializer(PickleSerializer(), 1024)

def _get_path(self, n):

Contributor:
Because there will be multiple Python worker processes running on the same node, if they all need to spill, it looks like they'll use the same directories in order here. Can you instead start each one at a random ID and then increment that to have it cycle through?

I'm not sure whether this can also affect the external hashing code, but if so, it would be good to fix that too (as a separate JIRA).

Contributor:
Basically I'm worried that everyone writes to disk1 first, then everyone writes to disk2, etc, and we only use one disk at a time.

Contributor Author:
Good catch; shuffling the directories randomly at the beginning might be better.

PS: Could there be a configurable policy for choosing local disks, such as using the first one as much as possible? That would be useful when one of the local disks is an SSD.

Contributor:
Yeah good question. We don't have that yet, but in the future we'll have support for multiple local storage levels.
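
A sketch of the "shuffle the directories" idea discussed above (an illustrative helper, not part of this diff): randomize the directory list once per sorter so concurrent workers start their spills on different disks.

import random

def _shuffled_dirs(dirs):
    """Return a randomized copy of the spill directories so that concurrent
    workers are unlikely to write their first spill files to the same disk."""
    dirs = list(dirs)
    random.shuffle(dirs)
    return dirs

# e.g. in ExternalSorter.__init__:
#     self.local_dirs = _shuffled_dirs(_get_local_dirs("sort"))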

""" Choose one directory for spill by number n """
d = self.local_dirs[n % len(self.local_dirs)]
if not os.path.exists(d):
os.makedirs(d)
return os.path.join(d, str(n))

def sorted(self, iterator, key=None, reverse=False):
"""
Sort the elements in the iterator, falling back to an external
sort when memory usage goes above the limit.
"""
batch = 10
chunks, current_chunk = [], []
iterator = iter(iterator)
while True:
# pick elements in batch
chunk = list(itertools.islice(iterator, batch))
current_chunk.extend(chunk)
if len(chunk) < batch:
break

if get_used_memory() > self.memory_limit:
# sorting the chunk in place saves memory
current_chunk.sort(key=key, reverse=reverse)
path = self._get_path(len(chunks))
with open(path, 'w') as f:
self.serializer.dump_stream(current_chunk, f)
chunks.append(self.serializer.load_stream(open(path)))
current_chunk = []

elif not chunks:
batch = min(batch * 2, 10000)

current_chunk.sort(key=key, reverse=reverse)
if not chunks:
return current_chunk

if current_chunk:
chunks.append(iter(current_chunk))

return heapq.merge(chunks, key=key, reverse=reverse)


if __name__ == "__main__":
import doctest
doctest.testmod()
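
A usage note on the class above (illustrative, mirroring the tests later in this diff): the memory limit is in megabytes, and sorted() returns a plain list when nothing spills but a merge iterator once spilling kicks in, so callers that need a list should wrap the result.

from pyspark.shuffle import ExternalSorter

data = [(i % 7, i) for i in range(100)]
# a 1 MB limit is always exceeded by the Python process, so this path spills and merges
result = list(ExternalSorter(1).sorted(data, key=lambda kv: kv[0]))
assert [k for k, _ in result] == sorted(k for k, _ in data)
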
25 changes: 24 additions & 1 deletion python/pyspark/tests.py
@@ -30,6 +30,7 @@
import tempfile
import time
import zipfile
import random

if sys.version_info[:2] <= (2, 6):
import unittest2 as unittest
@@ -40,7 +41,7 @@
from pyspark.context import SparkContext
from pyspark.files import SparkFiles
from pyspark.serializers import read_int
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger
from pyspark.shuffle import Aggregator, InMemoryMerger, ExternalMerger, ExternalSorter

_have_scipy = False
_have_numpy = False
@@ -117,6 +118,28 @@ def test_huge_dataset(self):
m._cleanup()


class TestSorter(unittest.TestCase):
def test_in_memory_sort(self):
l = range(1024)
random.shuffle(l)
sorter = ExternalSorter(1024)
self.assertEquals(sorted(l), sorter.sorted(l))
self.assertEquals(sorted(l, reverse=True), sorter.sorted(l, reverse=True))
self.assertEquals(sorted(l, key=lambda x: -x), sorter.sorted(l, key=lambda x: -x))
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
sorter.sorted(l, key=lambda x: -x, reverse=True))

def test_external_sort(self):
l = range(100)
random.shuffle(l)
sorter = ExternalSorter(1)
self.assertEquals(sorted(l), list(sorter.sorted(l)))
self.assertEquals(sorted(l, reverse=True), list(sorter.sorted(l, reverse=True)))
self.assertEquals(sorted(l, key=lambda x: -x), list(sorter.sorted(l, key=lambda x: -x)))
self.assertEquals(sorted(l, key=lambda x: -x, reverse=True),
list(sorter.sorted(l, key=lambda x: -x, reverse=True)))

Contributor:
Would be good to add a test that calls sortByKey on a large dataset here too, in addition to just testing the ExternalSorter separately.
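
A sketch of such a test (assuming a PySparkTestCase-style base class earlier in this file that provides a SparkContext as self.sc):

class TestSortByKeyOnLargeData(PySparkTestCase):
    def test_sortByKey(self):
        # unique keys, so the expected order is fully determined
        data = [(i, str(i)) for i in range(10000)]
        random.shuffle(data)
        rdd = self.sc.parallelize(data, 10)
        self.assertEquals(sorted(data), rdd.sortByKey().collect())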


class SerializationTestCase(unittest.TestCase):

def test_namedtuple(self):
2 changes: 1 addition & 1 deletion tox.ini
@@ -15,4 +15,4 @@

[pep8]
max-line-length=100
exclude=cloudpickle.py
exclude=cloudpickle.py,heapq3.py