Commit 7cd013b

Begin refactoring to enable proper tests for spilling.

1 parent 722849b · commit 7cd013b

3 files changed: +102 −37 lines

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java
(4 additions, 2 deletions)

@@ -21,9 +21,9 @@
 import java.io.IOException;
 import java.util.LinkedList;
 
-import org.apache.spark.storage.*;
 import scala.Tuple2;
 
+import com.google.common.annotations.VisibleForTesting;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -32,6 +32,7 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
 import org.apache.spark.unsafe.PlatformDependent;
 import org.apache.spark.unsafe.memory.MemoryBlock;
 import org.apache.spark.unsafe.memory.TaskMemoryManager;
@@ -215,7 +216,8 @@ private SpillInfo writeSpillFile() throws IOException {
   /**
    * Sort and spill the current records in response to memory pressure.
    */
-  private void spill() throws IOException {
+  @VisibleForTesting
+  void spill() throws IOException {
     final long threadId = Thread.currentThread().getId();
     logger.info("Thread " + threadId + " spilling sort data of " +
       org.apache.spark.util.Utils.bytesToString(getMemoryUsage()) + " to disk (" +
core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java
(60 additions, 32 deletions)

@@ -32,6 +32,7 @@
 import scala.reflect.ClassTag$;
 
 import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.io.ByteStreams;
 import com.google.common.io.Files;
 import org.slf4j.Logger;
@@ -73,6 +74,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {
   private final boolean transferToEnabled;
 
   private MapStatus mapStatus = null;
+  private UnsafeShuffleExternalSorter sorter = null;
+  private byte[] serArray = null;
+  private ByteBuffer serByteBuffer;
+  // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
+  private SerializationStream serOutputStream;
 
   /**
    * Are we in the process of stopping? Because map tasks can call stop() with success = true
@@ -113,56 +119,72 @@ public void write(Iterator<Product2<K, V>> records) {
   @Override
   public void write(scala.collection.Iterator<Product2<K, V>> records) {
     try {
-      final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records));
-      shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
-      mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+      while (records.hasNext()) {
+        insertRecordIntoSorter(records.next());
+      }
+      closeAndWriteOutput();
     } catch (Exception e) {
       PlatformDependent.throwException(e);
     }
   }
 
-  private void freeMemory() {
-    // TODO
-  }
-
-  private void deleteSpills() {
-    // TODO
-  }
-
-  private SpillInfo[] insertRecordsIntoSorter(
-      scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
-    final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter(
+  private void open() throws IOException {
+    assert (sorter == null);
+    sorter = new UnsafeShuffleExternalSorter(
       memoryManager,
       shuffleMemoryManager,
      blockManager,
      taskContext,
      4096, // Initial size (TODO: tune this!)
      partitioner.numPartitions(),
      sparkConf);
-
-    final byte[] serArray = new byte[SER_BUFFER_SIZE];
-    final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray);
+    serArray = new byte[SER_BUFFER_SIZE];
+    serByteBuffer = ByteBuffer.wrap(serArray);
     // TODO: we should not depend on this class from Kryo; copy its source or find an alternative
-    final SerializationStream serOutputStream =
-      serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+    serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+  }
 
-    while (records.hasNext()) {
-      final Product2<K, V> record = records.next();
-      final K key = record._1();
-      final int partitionId = partitioner.getPartition(key);
-      serByteBuffer.position(0);
-      serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
-      serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
-      serOutputStream.flush();
+  @VisibleForTesting
+  void closeAndWriteOutput() throws IOException {
+    if (sorter == null) {
+      open();
+    }
+    serArray = null;
+    serByteBuffer = null;
+    serOutputStream = null;
+    final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills());
+    sorter = null;
+    shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths);
+    mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths);
+  }
 
-      final int serializedRecordSize = serByteBuffer.position();
-      assert (serializedRecordSize > 0);
+  private void freeMemory() {
+    // TODO
+  }
 
-      sorter.insertRecord(
-        serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  @VisibleForTesting
+  void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
+    if (sorter == null) {
+      open();
     }
+    final K key = record._1();
+    final int partitionId = partitioner.getPartition(key);
+    serByteBuffer.position(0);
+    serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
+    serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
+    serOutputStream.flush();
 
-    return sorter.closeAndGetSpills();
+    final int serializedRecordSize = serByteBuffer.position();
+    assert (serializedRecordSize > 0);
+
+    sorter.insertRecord(
+      serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+  }
+
+  @VisibleForTesting
+  void forceSorterToSpill() throws IOException {
+    assert (sorter != null);
+    sorter.spill();
   }
 
   private long[] mergeSpills(SpillInfo[] spills) throws IOException {
@@ -222,6 +244,9 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) th
     for (int i = 0; i < spills.length; i++) {
       if (spillInputStreams[i] != null) {
         spillInputStreams[i].close();
+        if (!spills[i].file.delete()) {
+          logger.error("Error while deleting spill file {}", spills[i]);
+        }
       }
     }
     if (mergedFileOutputStream != null) {
@@ -282,6 +307,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th
       assert(spillInputChannelPositions[i] == spills[i].file.length());
       if (spillInputChannels[i] != null) {
         spillInputChannels[i].close();
+        if (!spills[i].file.delete()) {
+          logger.error("Error while deleting spill file {}", spills[i]);
+        }
      }
    }
    if (mergedFileOutputChannel != null) {
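
The net effect of this refactoring is that record insertion, spilling, and output writing are now separate package-private methods annotated with @VisibleForTesting, so a test can interleave records with forced spills instead of only calling write() on a complete iterator. A minimal usage sketch, assuming just the hooks introduced in this diff (the createWriter(...) helper is a test-suite fixture, referenced here purely for illustration):

    // Sketch: drives the new test-facing lifecycle added by this commit.
    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1)); // lazily opens the sorter
    writer.forceSorterToSpill();                                     // force a spill file mid-stream
    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
    writer.closeAndWriteOutput(); // merges spills, writes the index file, and sets mapStatus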

core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java
(38 additions, 3 deletions)

@@ -193,13 +193,13 @@ public void writeEmptyIterator() throws Exception {
   @Test
   public void writeWithoutSpilling() throws Exception {
     // In this example, each partition should have exactly one record:
-    final ArrayList<Product2<Object, Object>> datatToWrite =
+    final ArrayList<Product2<Object, Object>> dataToWrite =
       new ArrayList<Product2<Object, Object>>();
     for (int i = 0; i < NUM_PARTITITONS; i++) {
-      datatToWrite.add(new Tuple2<Object, Object>(i, i));
+      dataToWrite.add(new Tuple2<Object, Object>(i, i));
     }
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
-    writer.write(datatToWrite.iterator());
+    writer.write(dataToWrite.iterator());
     final Option<MapStatus> mapStatus = writer.stop(true);
     Assert.assertTrue(mapStatus.isDefined());
     Assert.assertTrue(mergedOutputFile.exists());
@@ -215,7 +215,42 @@ public void writeWithoutSpilling() throws Exception {
     assertSpillFilesWereCleanedUp();
   }
 
+  private void testMergingSpills(boolean transferToEnabled) throws IOException {
+    final UnsafeShuffleWriter<Object, Object> writer = createWriter(true);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(3, 3));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    writer.forceSorterToSpill();
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(4, 4));
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(2, 2));
+    writer.closeAndWriteOutput();
+    final Option<MapStatus> mapStatus = writer.stop(true);
+    Assert.assertTrue(mapStatus.isDefined());
+    Assert.assertTrue(mergedOutputFile.exists());
+    Assert.assertEquals(2, spillFilesCreated.size());
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizesInMergedFile) {
+      sumOfPartitionSizes += size;
+    }
+    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+
+    assertSpillFilesWereCleanedUp();
+  }
+
+  @Test
+  public void mergeSpillsWithTransferTo() throws Exception {
+    testMergingSpills(true);
+  }
+
+  @Test
+  public void mergeSpillsWithFileStream() throws Exception {
+    testMergingSpills(false);
+  }
+
   // TODO: actually try to read the shuffle output?
   // TODO: add a test that manually triggers spills in order to exercise the merging.
+  // }
 
 }
