32 | 32 | import scala.reflect.ClassTag$; |
33 | 33 |
34 | 34 | import com.esotericsoftware.kryo.io.ByteBufferOutputStream; |
| 35 | +import com.google.common.annotations.VisibleForTesting; |
35 | 36 | import com.google.common.io.ByteStreams; |
36 | 37 | import com.google.common.io.Files; |
37 | 38 | import org.slf4j.Logger; |
@@ -73,6 +74,11 @@ public class UnsafeShuffleWriter<K, V> extends ShuffleWriter<K, V> { |
73 | 74 | private final boolean transferToEnabled; |
74 | 75 |
75 | 76 | private MapStatus mapStatus = null; |
| 77 | + private UnsafeShuffleExternalSorter sorter = null; |
| 78 | + private byte[] serArray = null; |
| 79 | + private ByteBuffer serByteBuffer; |
| 80 | + // TODO: we should not depend on this class from Kryo; copy its source or find an alternative |
| 81 | + private SerializationStream serOutputStream; |
76 | 82 |
77 | 83 | /** |
78 | 84 | * Are we in the process of stopping? Because map tasks can call stop() with success = true |
@@ -113,56 +119,72 @@ public void write(Iterator<Product2<K, V>> records) { |
113 | 119 | @Override |
114 | 120 | public void write(scala.collection.Iterator<Product2<K, V>> records) { |
115 | 121 | try { |
116 | | - final long[] partitionLengths = mergeSpills(insertRecordsIntoSorter(records)); |
117 | | - shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); |
118 | | - mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); |
| 122 | + while (records.hasNext()) { |
| 123 | + insertRecordIntoSorter(records.next()); |
| 124 | + } |
| 125 | + closeAndWriteOutput(); |
119 | 126 | } catch (Exception e) { |
120 | 127 | PlatformDependent.throwException(e); |
121 | 128 | } |
122 | 129 | } |
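
Editor's note: the rewrite above turns write() into a streaming loop: each record goes through insertRecordIntoSorter() as it arrives, and closeAndWriteOutput() performs the merge once the iterator is exhausted. Since those methods (and forceSorterToSpill(), below) are annotated @VisibleForTesting, a test can drive the writer step by step. A minimal sketch, assuming a hypothetical createWriter() helper that supplies the real constructor arguments:

    // Hypothetical test driver for the hooks added in this commit; createWriter() is assumed.
    final UnsafeShuffleWriter<String, String> writer = createWriter();
    writer.insertRecordIntoSorter(new scala.Tuple2<String, String>("a", "1"));
    writer.forceSorterToSpill();  // force an intermediate spill to disk
    writer.insertRecordIntoSorter(new scala.Tuple2<String, String>("b", "2"));
    writer.closeAndWriteOutput(); // merges the spills and writes the index file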
123 | 130 |
124 | | - private void freeMemory() { |
125 | | - // TODO |
126 | | - } |
127 | | - |
128 | | - private void deleteSpills() { |
129 | | - // TODO |
130 | | - } |
131 | | - |
132 | | - private SpillInfo[] insertRecordsIntoSorter( |
133 | | - scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception { |
134 | | - final UnsafeShuffleExternalSorter sorter = new UnsafeShuffleExternalSorter( |
| 131 | + private void open() throws IOException { |
| 132 | + assert (sorter == null); |
| 133 | + sorter = new UnsafeShuffleExternalSorter( |
135 | 134 | memoryManager, |
136 | 135 | shuffleMemoryManager, |
137 | 136 | blockManager, |
138 | 137 | taskContext, |
139 | 138 | 4096, // Initial size (TODO: tune this!) |
140 | 139 | partitioner.numPartitions(), |
141 | 140 | sparkConf); |
142 | | - |
143 | | - final byte[] serArray = new byte[SER_BUFFER_SIZE]; |
144 | | - final ByteBuffer serByteBuffer = ByteBuffer.wrap(serArray); |
| 141 | + serArray = new byte[SER_BUFFER_SIZE]; |
| 142 | + serByteBuffer = ByteBuffer.wrap(serArray); |
145 | 143 | // TODO: we should not depend on this class from Kryo; copy its source or find an alternative |
146 | | - final SerializationStream serOutputStream = |
147 | | - serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); |
| 144 | + serOutputStream = serializer.serializeStream(new ByteBufferOutputStream(serByteBuffer)); |
| 145 | + } |
148 | 146 |
149 | | - while (records.hasNext()) { |
150 | | - final Product2<K, V> record = records.next(); |
151 | | - final K key = record._1(); |
152 | | - final int partitionId = partitioner.getPartition(key); |
153 | | - serByteBuffer.position(0); |
154 | | - serOutputStream.writeKey(key, OBJECT_CLASS_TAG); |
155 | | - serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); |
156 | | - serOutputStream.flush(); |
| 147 | + @VisibleForTesting |
| 148 | + void closeAndWriteOutput() throws IOException { |
| 149 | + if (sorter == null) { |
| 150 | + open(); |
| 151 | + } |
| 152 | + serArray = null; |
| 153 | + serByteBuffer = null; |
| 154 | + serOutputStream = null; |
| 155 | + final long[] partitionLengths = mergeSpills(sorter.closeAndGetSpills()); |
| 156 | + sorter = null; |
| 157 | + shuffleBlockManager.writeIndexFile(shuffleId, mapId, partitionLengths); |
| 158 | + mapStatus = MapStatus$.MODULE$.apply(blockManager.shuffleServerId(), partitionLengths); |
| 159 | + } |
157 | 160 |
158 | | - final int serializedRecordSize = serByteBuffer.position(); |
159 | | - assert (serializedRecordSize > 0); |
| 161 | + private void freeMemory() { |
| 162 | + // TODO |
| 163 | + } |
160 | 164 |
161 | | - sorter.insertRecord( |
162 | | - serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); |
| 165 | + @VisibleForTesting |
| 166 | + void insertRecordIntoSorter(Product2<K, V> record) throws IOException { |
| 167 | + if (sorter == null) { |
| 168 | + open(); |
163 | 169 | } |
| 170 | + final K key = record._1(); |
| 171 | + final int partitionId = partitioner.getPartition(key); |
| 172 | + serByteBuffer.position(0); |
| 173 | + serOutputStream.writeKey(key, OBJECT_CLASS_TAG); |
| 174 | + serOutputStream.writeValue(record._2(), OBJECT_CLASS_TAG); |
| 175 | + serOutputStream.flush(); |
164 | 176 |
165 | | - return sorter.closeAndGetSpills(); |
| 177 | + final int serializedRecordSize = serByteBuffer.position(); |
| 178 | + assert (serializedRecordSize > 0); |
| 179 | + |
| 180 | + sorter.insertRecord( |
| 181 | + serArray, PlatformDependent.BYTE_ARRAY_OFFSET, serializedRecordSize, partitionId); |
| 182 | + } |
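
Editor's note: the (serArray, BYTE_ARRAY_OFFSET, length, partitionId) call passes the sorter a raw byte range: Unsafe-style access addresses an array's payload as base object plus base offset, so the record's first byte lives at PlatformDependent.BYTE_ARRAY_OFFSET rather than at offset 0. A hedged illustration, assuming the UNSAFE wrapper mirrors sun.misc.Unsafe's getByte(Object, long):

    // Array payload starts at BYTE_ARRAY_OFFSET under Unsafe addressing (assumed wrapper API).
    final byte[] src = new byte[] { 42 };
    final byte first = PlatformDependent.UNSAFE.getByte(
        src, PlatformDependent.BYTE_ARRAY_OFFSET); // reads src[0], i.e. 42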
| 183 | + |
| 184 | + @VisibleForTesting |
| 185 | + void forceSorterToSpill() throws IOException { |
| 186 | + assert (sorter != null); |
| 187 | + sorter.spill(); |
166 | 188 | } |
167 | 189 |
168 | 190 | private long[] mergeSpills(SpillInfo[] spills) throws IOException { |
@@ -222,6 +244,9 @@ private long[] mergeSpillsWithFileStream(SpillInfo[] spills, File outputFile) th |
222 | 244 | for (int i = 0; i < spills.length; i++) { |
223 | 245 | if (spillInputStreams[i] != null) { |
224 | 246 | spillInputStreams[i].close(); |
| 247 | + if (!spills[i].file.delete()) { |
| 248 | + logger.error("Error while deleting spill file {}", spills[i].file.getPath()); |
| 249 | + } |
225 | 250 | } |
226 | 251 | } |
227 | 252 | if (mergedFileOutputStream != null) { |
@@ -282,6 +307,9 @@ private long[] mergeSpillsWithTransferTo(SpillInfo[] spills, File outputFile) th |
282 | 307 | assert(spillInputChannelPositions[i] == spills[i].file.length()); |
283 | 308 | if (spillInputChannels[i] != null) { |
284 | 309 | spillInputChannels[i].close(); |
| 310 | + if (!spills[i].file.delete()) { |
| 311 | + logger.error("Error while deleting spill file {}", spills[i].file.getPath()); |
| 312 | + } |
285 | 313 | } |
286 | 314 | } |
287 | 315 | if (mergedFileOutputChannel != null) { |
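
Editor's note: both merge paths now delete each spill file inside their finally blocks, immediately after closing its stream or channel, so the intermediate files are cleaned up even when the merge throws. The shape of the pattern on its own, with a stand-in Closeable handle and spill file:

    // Close-then-delete cleanup sketch; handle and spillFile are stand-ins.
    try {
      // ... merge work that may throw ...
    } finally {
      if (handle != null) {
        handle.close();              // release the descriptor first
        if (!spillFile.delete()) {   // then remove the on-disk spill
          logger.error("Error while deleting spill file {}", spillFile.getPath());
        }
      }
    }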