
Commit 133c8c9

WIP towards testing UnsafeShuffleWriter.
Unfortunately, this involved a TON of mocks; maybe it would be easier to split the writer into more objects, such as a spiller and merger, as I did when the sorting code was more generic.
1 parent f480fb2 commit 133c8c9
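For context on the refactoring the commit message floats: pulling the spill and merge phases out behind their own seams would let a test mock one or two small interfaces instead of the whole BlockManager/DiskBlockManager stack. A purely hypothetical sketch of that split (none of these interface names exist in this commit; SpillInfo is the commit's own class, stubbed here so the snippet compiles):

```java
import java.io.File;
import java.io.IOException;

// Stand-in for the commit's SpillInfo class, stubbed so this sketch compiles.
final class SpillInfo {}

// Hypothetical seams for the split the commit message suggests: a spiller
// that sorts and writes the in-memory buffer, and a merger that combines
// spill files into the final map output. With seams like these, a writer
// test could mock two narrow interfaces instead of a stack of managers.
interface ShuffleSpiller {
  // Sort the buffered records and write them out as one spill file.
  SpillInfo spill() throws IOException;
}

interface SpillMerger {
  // Merge per-partition segments from all spills into mergedOutput,
  // returning each partition's byte length in the merged file.
  long[] merge(SpillInfo[] spills, File mergedOutput) throws IOException;
}
```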

File tree

4 files changed (+215, -19 lines)

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleSpillWriter.java

Lines changed: 8 additions & 8 deletions

(In the first hunk the removed and added parameter lines differ only in leading whitespace: the constructor's parameters were re-indented.)

```diff
@@ -69,13 +69,13 @@ public final class UnsafeShuffleSpillWriter {
   private final LinkedList<SpillInfo> spills = new LinkedList<SpillInfo>();
 
   public UnsafeShuffleSpillWriter(
-    TaskMemoryManager memoryManager,
-    ShuffleMemoryManager shuffleMemoryManager,
-    BlockManager blockManager,
-    TaskContext taskContext,
-    int initialSize,
-    int numPartitions,
-    SparkConf conf) throws IOException {
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
+      BlockManager blockManager,
+      TaskContext taskContext,
+      int initialSize,
+      int numPartitions,
+      SparkConf conf) throws IOException {
     this.memoryManager = memoryManager;
     this.shuffleMemoryManager = shuffleMemoryManager;
     this.blockManager = blockManager;
@@ -266,7 +266,7 @@ public SpillInfo[] closeAndGetSpills() throws IOException {
     if (sorter != null) {
       writeSpillFile();
     }
-    return (SpillInfo[]) spills.toArray();
+    return spills.toArray(new SpillInfo[0]);
   }
 
 }
```
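The `closeAndGetSpills()` fix above sidesteps a standard Java trap: the no-argument `List.toArray()` returns an `Object[]` regardless of the list's element type, so casting the result to `SpillInfo[]` fails at runtime. A minimal, self-contained illustration of both forms:

```java
import java.util.LinkedList;
import java.util.List;

public class ToArrayPitfall {
  public static void main(String[] args) {
    List<String> list = new LinkedList<String>();
    list.add("spill-0");

    // The no-arg toArray() allocates an Object[], so this cast always
    // throws ClassCastException at runtime (uncomment to see it):
    // String[] broken = (String[]) list.toArray();

    // Passing a typed array makes toArray() return a real String[]:
    String[] ok = list.toArray(new String[0]);
    System.out.println(ok[0]); // spill-0
  }
}
```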

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 29 additions & 10 deletions

```diff
@@ -23,9 +23,12 @@
 import java.io.IOException;
 import java.nio.ByteBuffer;
 import java.nio.channels.FileChannel;
+import java.util.Iterator;
 
+import org.apache.spark.shuffle.ShuffleMemoryManager;
 import scala.Option;
 import scala.Product2;
+import scala.collection.JavaConversions;
 import scala.reflect.ClassTag;
 import scala.reflect.ClassTag$;
 
@@ -50,14 +53,18 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
   private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
   private static final ClassTag<Object> OBJECT_CLASS_TAG = ClassTag$.MODULE$.Object();
 
+  private final BlockManager blockManager;
   private final IndexShuffleBlockManager shuffleBlockManager;
-  private final BlockManager blockManager = SparkEnv.get().blockManager();
-  private final int shuffleId;
-  private final int mapId;
   private final TaskMemoryManager memoryManager;
+  private final ShuffleMemoryManager shuffleMemoryManager;
   private final SerializerInstance serializer;
   private final Partitioner partitioner;
   private final ShuffleWriteMetrics writeMetrics;
+  private final int shuffleId;
+  private final int mapId;
+  private final TaskContext taskContext;
+  private final SparkConf sparkConf;
+
   private MapStatus mapStatus = null;
 
   /**
@@ -68,19 +75,31 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
   private boolean stopping = false;
 
   public UnsafeShuffleWriter(
+      BlockManager blockManager,
       IndexShuffleBlockManager shuffleBlockManager,
+      TaskMemoryManager memoryManager,
+      ShuffleMemoryManager shuffleMemoryManager,
       UnsafeShuffleHandle<K, V> handle,
       int mapId,
-      TaskContext context) {
+      TaskContext taskContext,
+      SparkConf sparkConf) {
+    this.blockManager = blockManager;
     this.shuffleBlockManager = shuffleBlockManager;
+    this.memoryManager = memoryManager;
+    this.shuffleMemoryManager = shuffleMemoryManager;
     this.mapId = mapId;
-    this.memoryManager = context.taskMemoryManager();
     final ShuffleDependency<K, V, V> dep = handle.dependency();
     this.shuffleId = dep.shuffleId();
     this.serializer = Serializer.getSerializer(dep.serializer()).newInstance();
     this.partitioner = dep.partitioner();
     this.writeMetrics = new ShuffleWriteMetrics();
-    context.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    taskContext.taskMetrics().shuffleWriteMetrics_$eq(Option.apply(writeMetrics));
+    this.taskContext = taskContext;
+    this.sparkConf = sparkConf;
+  }
+
+  public void write(Iterator<Product2<K, V>> records) {
+    write(JavaConversions.asScalaIterator(records));
   }
 
   public void write(scala.collection.Iterator<Product2<K, V>> records) {
@@ -101,12 +120,12 @@ private SpillInfo[] insertRecordsIntoSorter(
       scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
     final UnsafeShuffleSpillWriter sorter = new UnsafeShuffleSpillWriter(
       memoryManager,
-      SparkEnv$.MODULE$.get().shuffleMemoryManager(),
-      SparkEnv$.MODULE$.get().blockManager(),
-      TaskContext.get(),
+      shuffleMemoryManager,
+      blockManager,
+      taskContext,
       4096, // Initial size (TODO: tune this!)
       partitioner.numPartitions(),
-      SparkEnv$.MODULE$.get().conf()
+      sparkConf
     );
 
     final byte[] serArray = new byte[SER_BUFFER_SIZE];
```
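Two things happen in this file. First, the writer's collaborators (BlockManager, the memory managers, TaskContext, SparkConf) are now injected through the constructor instead of fetched from `SparkEnv.get()`, which is what makes the mock-based suite below possible. Second, a Java-friendly `write(Iterator)` overload bridges to the existing Scala-iterator method via `JavaConversions.asScalaIterator`. A small standalone sketch of that bridge (the data here is made up for illustration; `asScalaIterator` is the standard conversion in the Scala versions this code builds against):

```java
import java.util.Arrays;
import java.util.Iterator;

import scala.collection.JavaConversions;

public class IteratorBridge {
  public static void main(String[] args) {
    // A plain java.util.Iterator, like the one the new overload accepts.
    Iterator<String> javaIter = Arrays.asList("a", "b", "c").iterator();

    // Wrap it so it satisfies a scala.collection.Iterator parameter.
    scala.collection.Iterator<String> scalaIter =
      JavaConversions.asScalaIterator(javaIter);

    while (scalaIter.hasNext()) {
      System.out.println(scalaIter.next());
    }
  }
}
```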

core/src/main/scala/org/apache/spark/shuffle/unsafe/UnsafeShuffleManager.scala

Lines changed: 6 additions & 1 deletion

```diff
@@ -88,12 +88,17 @@ private[spark] class UnsafeShuffleManager(conf: SparkConf) extends ShuffleManage
       context: TaskContext): ShuffleWriter[K, V] = {
     handle match {
       case unsafeShuffleHandle: UnsafeShuffleHandle[K, V] =>
+        val env = SparkEnv.get
         // TODO: do we need to do anything to register the shuffle here?
         new UnsafeShuffleWriter(
+          env.blockManager,
           shuffleBlockResolver.asInstanceOf[IndexShuffleBlockManager],
+          context.taskMemoryManager(),
+          env.shuffleMemoryManager,
           unsafeShuffleHandle,
           mapId,
-          context)
+          context,
+          env.conf)
       case other =>
         sortShuffleManager.getWriter(handle, mapId, context)
     }
```
core/src/test/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriterSuite.java

Lines changed: 172 additions & 0 deletions

```diff
@@ -0,0 +1,172 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.shuffle.unsafe;
+
+import java.io.File;
+import java.io.InputStream;
+import java.io.OutputStream;
+import java.util.ArrayList;
+import java.util.UUID;
+
+import scala.*;
+import scala.runtime.AbstractFunction1;
+
+import org.junit.Assert;
+import org.junit.Before;
+import org.junit.Test;
+import org.mockito.invocation.InvocationOnMock;
+import org.mockito.stubbing.Answer;
+import static org.mockito.AdditionalAnswers.returnsFirstArg;
+import static org.mockito.AdditionalAnswers.returnsSecondArg;
+import static org.mockito.Mockito.*;
+
+import org.apache.spark.*;
+import org.apache.spark.serializer.Serializer;
+import org.apache.spark.shuffle.IndexShuffleBlockManager;
+import org.apache.spark.executor.ShuffleWriteMetrics;
+import org.apache.spark.executor.TaskMetrics;
+import org.apache.spark.serializer.SerializerInstance;
+import org.apache.spark.shuffle.ShuffleMemoryManager;
+import org.apache.spark.storage.*;
+import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
+import org.apache.spark.unsafe.memory.MemoryAllocator;
+import org.apache.spark.unsafe.memory.TaskMemoryManager;
+import org.apache.spark.util.Utils;
+import org.apache.spark.serializer.KryoSerializer;
+import org.apache.spark.scheduler.MapStatus;
+
+public class UnsafeShuffleWriterSuite {
+
+  final TaskMemoryManager memoryManager =
+    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
+  // Compute key prefixes based on the records' partition ids
+  final HashPartitioner hashPartitioner = new HashPartitioner(4);
+
+  ShuffleMemoryManager shuffleMemoryManager;
+  BlockManager blockManager;
+  IndexShuffleBlockManager shuffleBlockManager;
+  DiskBlockManager diskBlockManager;
+  File tempDir;
+  TaskContext taskContext;
+  SparkConf sparkConf;
+
+  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
+    @Override
+    public OutputStream apply(OutputStream stream) {
+      return stream;
+    }
+  }
+
+  @Before
+  public void setUp() {
+    shuffleMemoryManager = mock(ShuffleMemoryManager.class);
+    diskBlockManager = mock(DiskBlockManager.class);
+    blockManager = mock(BlockManager.class);
+    shuffleBlockManager = mock(IndexShuffleBlockManager.class);
+    tempDir = new File(Utils.createTempDir$default$1());
+    taskContext = mock(TaskContext.class);
+    sparkConf = new SparkConf();
+    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
+    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
+    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
+    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
+      @Override
+      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
+        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
+        File file = File.createTempFile("spillFile", ".spill", tempDir);
+        return Tuple2$.MODULE$.apply(blockId, file);
+      }
+    });
+    when(blockManager.getDiskWriter(
+      any(BlockId.class),
+      any(File.class),
+      any(SerializerInstance.class),
+      anyInt(),
+      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
+      @Override
+      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
+        Object[] args = invocationOnMock.getArguments();
+
+        return new DiskBlockObjectWriter(
+          (BlockId) args[0],
+          (File) args[1],
+          (SerializerInstance) args[2],
+          (Integer) args[3],
+          new CompressStream(),
+          false,
+          (ShuffleWriteMetrics) args[4]
+        );
+      }
+    });
+    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
+      .then(returnsSecondArg());
+  }
+
+  @Test
+  public void basicShuffleWriting() throws Exception {
+
+    final ShuffleDependency<Object, Object, Object> dep = mock(ShuffleDependency.class);
+    when(dep.serializer()).thenReturn(Option.<Serializer>apply(new KryoSerializer(sparkConf)));
+    when(dep.partitioner()).thenReturn(hashPartitioner);
+
+    final File mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
+    when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
+    final long[] partitionSizes = new long[hashPartitioner.numPartitions()];
+    doAnswer(new Answer<Void>() {
+      @Override
+      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
+        long[] receivedPartitionSizes = (long[]) invocationOnMock.getArguments()[2];
+        System.arraycopy(
+          receivedPartitionSizes, 0, partitionSizes, 0, receivedPartitionSizes.length);
+        return null;
+      }
+    }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class));
+
+    final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<Object, Object>(
+      blockManager,
+      shuffleBlockManager,
+      memoryManager,
+      shuffleMemoryManager,
+      new UnsafeShuffleHandle<Object, Object>(0, 1, dep),
+      0, // map id
+      taskContext,
+      sparkConf
+    );
+
+    final ArrayList<Product2<Object, Object>> numbersToSort =
+      new ArrayList<Product2<Object, Object>>();
+    numbersToSort.add(new Tuple2<Object, Object>(5, 5));
+    numbersToSort.add(new Tuple2<Object, Object>(1, 1));
+    numbersToSort.add(new Tuple2<Object, Object>(3, 3));
+    numbersToSort.add(new Tuple2<Object, Object>(2, 2));
+    numbersToSort.add(new Tuple2<Object, Object>(4, 4));
+
+
+    writer.write(numbersToSort.iterator());
+    final MapStatus mapStatus = writer.stop(true).get();
+
+    long sumOfPartitionSizes = 0;
+    for (long size: partitionSizes) {
+      sumOfPartitionSizes += size;
+    }
+    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);
+
+    // TODO: test that the temporary spill files were cleaned up after the merge.
+  }
+
+}
```
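One way to discharge the trailing TODO, sketched under the assumption that a successful merge is supposed to delete its temporary spill files: have the mocked `createTempLocalBlock()` answer record every spill file it hands out, then assert after `writer.stop(true)` that none of them still exist. The names below all come from the suite above; only the `spillFiles` list is new:

```java
// Hedged sketch for the TODO above. The suite would keep a list of the
// spill files created by the mocked createTempLocalBlock() answer:
final ArrayList<File> spillFiles = new ArrayList<File>();
// ... and inside that Answer, after File.createTempFile(...):
//   spillFiles.add(file);

// Then, at the end of basicShuffleWriting(), after writer.stop(true):
for (File spillFile : spillFiles) {
  Assert.assertFalse(
    "spill file should be deleted once the merge succeeds",
    spillFile.exists());
}
```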
