/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *    http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.spark.shuffle.unsafe;

import java.io.File;
import java.io.InputStream;
import java.io.OutputStream;
import java.util.ArrayList;
import java.util.UUID;

import scala.*;
import scala.runtime.AbstractFunction1;

import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import static org.mockito.AdditionalAnswers.returnsFirstArg;
import static org.mockito.AdditionalAnswers.returnsSecondArg;
import static org.mockito.Mockito.*;

import org.apache.spark.*;
import org.apache.spark.executor.ShuffleWriteMetrics;
import org.apache.spark.executor.TaskMetrics;
import org.apache.spark.scheduler.MapStatus;
import org.apache.spark.serializer.KryoSerializer;
import org.apache.spark.serializer.Serializer;
import org.apache.spark.serializer.SerializerInstance;
import org.apache.spark.shuffle.IndexShuffleBlockManager;
import org.apache.spark.shuffle.ShuffleMemoryManager;
import org.apache.spark.storage.*;
import org.apache.spark.unsafe.memory.ExecutorMemoryManager;
import org.apache.spark.unsafe.memory.MemoryAllocator;
import org.apache.spark.unsafe.memory.TaskMemoryManager;
import org.apache.spark.util.Utils;
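
/**
 * Test suite for {@link UnsafeShuffleWriter}. The block, disk, and memory managers are
 * mocked so that all shuffle I/O is routed to files in a local temporary directory.
 */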
public class UnsafeShuffleWriterSuite {

  final TaskMemoryManager memoryManager =
    new TaskMemoryManager(new ExecutorMemoryManager(MemoryAllocator.HEAP));
  // Partition records by hashing their keys into four buckets; the writer sorts
  // records by partition id
  final HashPartitioner hashPartitioner = new HashPartitioner(4);

  ShuffleMemoryManager shuffleMemoryManager;
  BlockManager blockManager;
  IndexShuffleBlockManager shuffleBlockManager;
  DiskBlockManager diskBlockManager;
  File tempDir;
  TaskContext taskContext;
  SparkConf sparkConf;

  // Identity "compression" function: the mocked BlockManager applies no compression,
  // so wrapped streams are returned unchanged.
  private static final class CompressStream extends AbstractFunction1<OutputStream, OutputStream> {
    @Override
    public OutputStream apply(OutputStream stream) {
      return stream;
    }
  }
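
  // Wires up the mocks: memory requests always succeed, temp blocks are backed by real
  // files in tempDir, and disk writers write to those files without compression.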
  @Before
  public void setUp() {
    shuffleMemoryManager = mock(ShuffleMemoryManager.class);
    diskBlockManager = mock(DiskBlockManager.class);
    blockManager = mock(BlockManager.class);
    shuffleBlockManager = mock(IndexShuffleBlockManager.class);
    // Actually create a fresh scratch directory; wrapping createTempDir$default$1() in a
    // File only named the default temp root without creating anything
    tempDir = Utils.createTempDir("test", "test");
    taskContext = mock(TaskContext.class);
    sparkConf = new SparkConf();
    when(taskContext.taskMetrics()).thenReturn(new TaskMetrics());
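    // Grant every memory acquisition request in full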
    when(shuffleMemoryManager.tryToAcquire(anyLong())).then(returnsFirstArg());
    when(blockManager.diskBlockManager()).thenReturn(diskBlockManager);
    // Back each temporary local block with a real file in tempDir
    when(diskBlockManager.createTempLocalBlock()).thenAnswer(new Answer<Tuple2<TempLocalBlockId, File>>() {
      @Override
      public Tuple2<TempLocalBlockId, File> answer(InvocationOnMock invocationOnMock) throws Throwable {
        TempLocalBlockId blockId = new TempLocalBlockId(UUID.randomUUID());
        File file = File.createTempFile("spillFile", ".spill", tempDir);
        return Tuple2$.MODULE$.apply(blockId, file);
      }
    });
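    // Hand out real DiskBlockObjectWriters so that spill data actually reaches disk
    // and can later be merged into the final output file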
    when(blockManager.getDiskWriter(
      any(BlockId.class),
      any(File.class),
      any(SerializerInstance.class),
      anyInt(),
      any(ShuffleWriteMetrics.class))).thenAnswer(new Answer<DiskBlockObjectWriter>() {
      @Override
      public DiskBlockObjectWriter answer(InvocationOnMock invocationOnMock) throws Throwable {
        Object[] args = invocationOnMock.getArguments();
        return new DiskBlockObjectWriter(
          (BlockId) args[0],
          (File) args[1],
          (SerializerInstance) args[2],
          (Integer) args[3],
          new CompressStream(),
          false,  // syncWrites
          (ShuffleWriteMetrics) args[4]
        );
      }
    });
    // Decompression on read is a no-op: return the wrapped input stream unchanged
    when(blockManager.wrapForCompression(any(BlockId.class), any(InputStream.class)))
      .then(returnsSecondArg());
  }
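
  /**
   * Writes a handful of records through the unsafe shuffle path, then verifies that the
   * per-partition sizes reported to the index file sum to the merged data file's length.
   */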
  @Test
  public void basicShuffleWriting() throws Exception {
    final ShuffleDependency<Object, Object, Object> dep = mock(ShuffleDependency.class);
    when(dep.serializer()).thenReturn(Option.<Serializer>apply(new KryoSerializer(sparkConf)));
    when(dep.partitioner()).thenReturn(hashPartitioner);

    final File mergedOutputFile = File.createTempFile("mergedoutput", "", tempDir);
    when(shuffleBlockManager.getDataFile(anyInt(), anyInt())).thenReturn(mergedOutputFile);
    // Capture the per-partition lengths that the writer records in the index file
    final long[] partitionSizes = new long[hashPartitioner.numPartitions()];
    doAnswer(new Answer<Void>() {
      @Override
      public Void answer(InvocationOnMock invocationOnMock) throws Throwable {
        long[] receivedPartitionSizes = (long[]) invocationOnMock.getArguments()[2];
        System.arraycopy(
          receivedPartitionSizes, 0, partitionSizes, 0, receivedPartitionSizes.length);
        return null;
      }
    }).when(shuffleBlockManager).writeIndexFile(anyInt(), anyInt(), any(long[].class));

    // Writer under test: shuffle id 0 with a single map task in the stage
    final UnsafeShuffleWriter<Object, Object> writer = new UnsafeShuffleWriter<Object, Object>(
      blockManager,
      shuffleBlockManager,
      memoryManager,
      shuffleMemoryManager,
      new UnsafeShuffleHandle<Object, Object>(0, 1, dep),
      0,  // map id
      taskContext,
      sparkConf
    );

    final ArrayList<Product2<Object, Object>> numbersToSort =
      new ArrayList<Product2<Object, Object>>();
    numbersToSort.add(new Tuple2<Object, Object>(5, 5));
    numbersToSort.add(new Tuple2<Object, Object>(1, 1));
    numbersToSort.add(new Tuple2<Object, Object>(3, 3));
    numbersToSort.add(new Tuple2<Object, Object>(2, 2));
    numbersToSort.add(new Tuple2<Object, Object>(4, 4));

    writer.write(numbersToSort.iterator());
    final MapStatus mapStatus = writer.stop(true).get();
    Assert.assertNotNull(mapStatus);

    // The merged output file should be exactly the concatenation of all the partitions
    long sumOfPartitionSizes = 0;
    for (long size : partitionSizes) {
      sumOfPartitionSizes += size;
    }
    Assert.assertEquals(mergedOutputFile.length(), sumOfPartitionSizes);

    // TODO: test that the temporary spill files were cleaned up after the merge.
  }
}