Commit 27b18b0

Test for inserting records AT the max record size.
1 parent fcd9a3c commit 27b18b0

3 files changed: +70 -26 lines changed

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleExternalSorter.java

Lines changed: 3 additions & 1 deletion
@@ -57,9 +57,11 @@ final class UnsafeShuffleExternalSorter {

   private final Logger logger = LoggerFactory.getLogger(UnsafeShuffleExternalSorter.class);

+  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
   @VisibleForTesting
   static final int DISK_WRITE_BUFFER_SIZE = 1024 * 1024;
-  private static final int PAGE_SIZE = PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES;
+  @VisibleForTesting
+  static final int MAX_RECORD_SIZE = PAGE_SIZE - 4;

   private final int initialSize;
   private final int numPartitions;

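Why `PAGE_SIZE - 4`: the sorter appears to store each record in a memory page behind a 4-byte length prefix, so the largest payload that fits in a single page is the page size minus that prefix. A minimal sketch of that arithmetic (the class below and the concrete page-size value are illustrative stand-ins, not code from this commit):

// Illustrative sketch only: models a record stored in a fixed-size page as
// [ 4-byte length prefix | record bytes ], which is why the largest insertable
// record is one page minus the prefix.
final class PageLayoutSketch {
  static final int PAGE_SIZE = 1 << 27;       // stand-in for PackedRecordPointer.MAXIMUM_PAGE_SIZE_BYTES
  static final int LENGTH_PREFIX_BYTES = 4;   // assumed size of the per-record length header
  static final int MAX_RECORD_SIZE = PAGE_SIZE - LENGTH_PREFIX_BYTES;

  /** True if a record of the given length fits in a single empty page. */
  static boolean fitsInOnePage(int recordLength) {
    return recordLength >= 0 && recordLength + LENGTH_PREFIX_BYTES <= PAGE_SIZE;
  }

  public static void main(String[] args) {
    System.out.println(fitsInOnePage(MAX_RECORD_SIZE));      // true
    System.out.println(fitsInOnePage(MAX_RECORD_SIZE + 1));  // false
  }
}
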
core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 14 additions & 11 deletions
@@ -18,7 +18,6 @@
 package org.apache.spark.shuffle.unsafe;

 import java.io.*;
-import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
 import java.util.Iterator;

@@ -73,8 +72,14 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> {

   private MapStatus mapStatus = null;
   private UnsafeShuffleExternalSorter sorter = null;
-  private byte[] serArray = null;
-  private ByteBuffer serByteBuffer;
+
+  /** Subclass of ByteArrayOutputStream that exposes `buf` directly. */
+  private static final class MyByteArrayOutputStream extends ByteArrayOutputStream {
+    public MyByteArrayOutputStream(int size) { super(size); }
+    public byte[] getBuf() { return buf; }
+  }
+
+  private MyByteArrayOutputStream serBuffer;
   private SerializationStream serOutputStream;

   /**
@@ -142,18 +147,16 @@ private void open() throws IOException {
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
       sparkConf);
-    serArray = new byte[MAXIMUM_RECORD_SIZE];
-    serByteBuffer = ByteBuffer.wrap(serArray);
-    serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer));
+    serBuffer = new MyByteArrayOutputStream(1024 * 1024);
+    serOutputStream = serializer.serializeStream(serBuffer);
   }

   @VisibleForTesting
   void closeAndWriteOutput() throws IOException {
     if (sorter == null) {
       open();
     }
-    serArray = null;
-    serByteBuffer = null;
+    serBuffer = null;
     serOutputStream = null;
     final SpillInfo[] spills = sorter.closeAndGetSpills();
     sorter = null;
@@ -178,16 +181,16 @@ void insertRecordIntoSorter(Product2<K, V> record) throws IOException{
     }
     final K key = record._1();
     final int partitionId = partitioner.getPartition(key);
-    serByteBuffer.position(0);
+    serBuffer.reset();
     serOutputStream.writeKey(key, OBJECT_CLASS_TAG);
     serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG);
     serOutputStream.flush();

-    final int serializedRecordSize = serByteBuffer.position();
+    final int serializedRecordSize = serBuffer.size();
     assert (serializedRecordSize > 0);

     sorter.insertRecord(
-      serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
+      serBuffer.getBuf(), PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId);
   }

   @VisibleForTesting

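The writer-side change replaces the fixed-size `serArray`/`serByteBuffer` pair with a growable `ByteArrayOutputStream` subclass whose backing array is exposed directly, so serialized bytes can be handed to the sorter without the copy that `toByteArray()` would make; the record-size limit is then enforced by the sorter's `MAX_RECORD_SIZE` rather than by a fixed serialization buffer. A standalone sketch of the same pattern (class and method names here are illustrative, not the writer's actual code):

import java.io.ByteArrayOutputStream;

// Illustrative sketch: subclass ByteArrayOutputStream to expose the protected `buf`
// field, so callers can hand the backing array downstream without the defensive copy
// that toByteArray() would make. Only the first size() bytes are valid.
final class ExposedByteArrayOutputStream extends ByteArrayOutputStream {
  ExposedByteArrayOutputStream(int initialSize) { super(initialSize); }
  byte[] getBuf() { return buf; }
}

class ExposedBufferDemo {
  public static void main(String[] args) {
    ExposedByteArrayOutputStream out = new ExposedByteArrayOutputStream(1024 * 1024);
    out.write(42);                    // stand-in for writing one serialized record
    byte[] backing = out.getBuf();    // no copy; may be longer than the valid contents
    int validBytes = out.size();      // pass (backing, offset 0, validBytes) to the consumer
    out.reset();                      // reuse the same buffer for the next record
    System.out.println(validBytes + " valid byte(s) in a backing array of length " + backing.length);
  }
}

Only the first `size()` bytes of the backing array are meaningful, which is why the writer passes `serializedRecordSize` alongside `serBuffer.getBuf()` when inserting into the sorter.
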
core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java

Lines changed: 53 additions & 14 deletions
@@ -23,6 +23,7 @@

 import scala.*;
 import scala.collection.Iterator;
+import scala.reflect.ClassTag;
 import scala.runtime.AbstractFunction1;

 import com.google.common.collect.HashMultiset;
@@ -44,11 +45,8 @@
 import org.apache.spark.executor.ShuffleWriteMetrics;
 import org.apache.spark.executor.TaskMetrics;
 import org.apache.spark.network.util.LimitedInputStream;
+import org.apache.spark.serializer.*;
 import org.apache.spark.scheduler.MapStatus;
-import org.apache.spark.serializer.DeserializationStream;
-import org.apache.spark.serializer.KryoSerializer;
-import org.apache.spark.serializer.Serializer;
-import org.apache.spark.serializer.SerializerInstance;
 import org.apache.spark.shuffle.IndexShuffleBlockResolver;
 import org.apache.spark.shuffle.ShuffleMemoryManager;
 import org.apache.spark.storage.*;
@@ -305,18 +303,59 @@ public void writeRecordsThatAreBiggerThanDiskWriteBufferSize() throws Exception
   }

   @Test
-  public void writeRecordsThatAreBiggerThanMaximumRecordSize() throws Exception {
+  public void writeRecordsThatAreBiggerThanMaxRecordSize() throws Exception {
+    // Use a custom serializer so that we have exact control over the size of serialized data.
+    final Serializer byteArraySerializer = new Serializer() {
+      @Override
+      public SerializerInstance newInstance() {
+        return new SerializerInstance() {
+          @Override
+          public SerializationStream serializeStream(final OutputStream s) {
+            return new SerializationStream() {
+              @Override
+              public void flush() { }
+
+              @Override
+              public <T> SerializationStream writeObject(T t, ClassTag<T> ev1) {
+                byte[] bytes = (byte[]) t;
+                try {
+                  s.write(bytes);
+                } catch (IOException e) {
+                  throw new RuntimeException(e);
+                }
+                return this;
+              }
+
+              @Override
+              public void close() { }
+            };
+          }
+          public <T> ByteBuffer serialize(T t, ClassTag<T> ev1) { return null; }
+          public DeserializationStream deserializeStream(InputStream s) { return null; }
+          public <T> T deserialize(ByteBuffer b, ClassLoader l, ClassTag<T> ev1) { return null; }
+          public <T> T deserialize(ByteBuffer bytes, ClassTag<T> ev1) { return null; }
+        };
+      }
+    };
+    when(shuffleDep.serializer()).thenReturn(Option.<Serializer>apply(byteArraySerializer));
     final UnsafeShuffleWriter<Object, Object> writer = createWriter(false);
-    final ArrayList<Product2<Object, Object>> dataToWrite =
-      new ArrayList<Product2<Object, Object>>();
-    final byte[] bytes = new byte[UnsafeShuffleWriter.MAXIMUM_RECORD_SIZE * 2];
-    new Random(42).nextBytes(bytes);
-    dataToWrite.add(new Tuple2<Object, Object>(1, bytes));
+    // Insert a record and force a spill so that there's something to clean up:
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[1], new byte[1]));
+    writer.forceSorterToSpill();
+    // We should be able to write a record that's right _at_ the max record size
+    final byte[] atMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE];
+    new Random(42).nextBytes(atMaxRecordSize);
+    writer.insertRecordIntoSorter(new Tuple2<Object, Object>(new byte[0], atMaxRecordSize));
+    writer.forceSorterToSpill();
+    // Inserting a record that's larger than the max record size should fail:
+    final byte[] exceedsMaxRecordSize = new byte[UnsafeShuffleExternalSorter.MAX_RECORD_SIZE + 1];
+    new Random(42).nextBytes(exceedsMaxRecordSize);
+    Product2<Object, Object> hugeRecord =
+      new Tuple2<Object, Object>(new byte[0], exceedsMaxRecordSize);
     try {
-      // Insert a record and force a spill so that there's something to clean up:
-      writer.insertRecordIntoSorter(new Tuple2<Object, Object>(1, 1));
-      writer.forceSorterToSpill();
-      writer.write(dataToWrite.iterator());
+      // Here, we write through the public `write()` interface instead of the test-only
+      // `insertRecordIntoSorter` interface:
+      writer.write(Collections.singletonList(hugeRecord).iterator());
       Assert.fail("Expected exception to be thrown");
     } catch (IOException e) {
       // Pass

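Because the test's pass-through serializer writes key and value bytes verbatim, the serialized record size is simply the key length plus the value length, which is how the test lands exactly on the limit by pairing a zero-length key with a `MAX_RECORD_SIZE`-byte value. A hypothetical back-of-the-envelope check of the two boundary cases (the concrete constant is a stand-in, not taken from the commit):

// Illustrative arithmetic only: with a serializer that writes raw bytes, the record
// size seen by the sorter is keyLength + valueLength.
class MaxRecordSizeBoundarySketch {
  public static void main(String[] args) {
    final int maxRecordSize = (1 << 27) - 4;         // stand-in for UnsafeShuffleExternalSorter.MAX_RECORD_SIZE
    int atLimit = 0 + maxRecordSize;                 // byte[0] key + byte[maxRecordSize] value
    int overLimit = 0 + maxRecordSize + 1;           // byte[0] key + byte[maxRecordSize + 1] value
    System.out.println(atLimit <= maxRecordSize);    // true: the sorter should accept it
    System.out.println(overLimit <= maxRecordSize);  // false: write() is expected to throw IOException
  }
}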