Skip to content

Commit f156a8f

Browse files
committed
Hacky metrics integration; refactor some interfaces.
1 parent 2776aca commit f156a8f

File tree

8 files changed

+151
-45
lines changed

8 files changed

+151
-45
lines changed

core/src/main/java/org/apache/spark/shuffle/unsafe/UnsafeShuffleWriter.java

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
package org.apache.spark.shuffle.unsafe;
1919

2020
import org.apache.spark.*;
21-
import org.apache.spark.unsafe.sort.UnsafeExternalSortSpillMerger;
21+
import org.apache.spark.unsafe.sort.ExternalSorterIterator;
2222
import org.apache.spark.unsafe.sort.UnsafeExternalSorter;
2323
import scala.Option;
2424
import scala.Product2;
@@ -28,7 +28,6 @@
2828
import java.io.File;
2929
import java.io.IOException;
3030
import java.nio.ByteBuffer;
31-
import java.util.Iterator;
3231
import java.util.LinkedList;
3332

3433
import com.esotericsoftware.kryo.io.ByteBufferOutputStream;
@@ -47,7 +46,7 @@
4746
import org.apache.spark.unsafe.PlatformDependent;
4847
import org.apache.spark.unsafe.memory.MemoryBlock;
4948
import org.apache.spark.unsafe.memory.TaskMemoryManager;
50-
import org.apache.spark.unsafe.sort.UnsafeSorter;
49+
5150
import static org.apache.spark.unsafe.sort.UnsafeSorter.*;
5251

5352
// IntelliJ gets confused and claims that this class should be abstract, but this actually compiles
@@ -64,7 +63,6 @@ public class UnsafeShuffleWriter<K, V> implements ShuffleWriter<K, V> {
6463
private final SerializerInstance serializer;
6564
private final Partitioner partitioner;
6665
private final ShuffleWriteMetrics writeMetrics;
67-
private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
6866
private final int fileBufferSize;
6967
private MapStatus mapStatus = null;
7068

@@ -108,12 +106,13 @@ private void freeMemory() {
108106
// TODO: free sorter memory
109107
}
110108

111-
private Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortRecords(
109+
private ExternalSorterIterator sortRecords(
112110
scala.collection.Iterator<? extends Product2<K, V>> records) throws Exception {
113111
final UnsafeExternalSorter sorter = new UnsafeExternalSorter(
114112
memoryManager,
115113
SparkEnv$.MODULE$.get().shuffleMemoryManager(),
116114
SparkEnv$.MODULE$.get().blockManager(),
115+
TaskContext.get(),
117116
RECORD_COMPARATOR,
118117
PREFIX_COMPARATOR,
119118
4096, // Initial size (TODO: tune this!)
@@ -145,8 +144,7 @@ private Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortRe
145144
return sorter.getSortedIterator();
146145
}
147146

148-
private long[] writeSortedRecordsToFile(
149-
Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> sortedRecords) throws IOException {
147+
private long[] writeSortedRecordsToFile(ExternalSorterIterator sortedRecords) throws IOException {
150148
final File outputFile = shuffleBlockManager.getDataFile(shuffleId, mapId);
151149
final ShuffleBlockId blockId =
152150
new ShuffleBlockId(shuffleId, mapId, IndexShuffleBlockManager.NOOP_REDUCE_ID());
@@ -157,8 +155,8 @@ private long[] writeSortedRecordsToFile(
157155

158156
final byte[] arr = new byte[SER_BUFFER_SIZE];
159157
while (sortedRecords.hasNext()) {
160-
final UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix recordPointer = sortedRecords.next();
161-
final int partition = (int) recordPointer.keyPrefix;
158+
sortedRecords.loadNext();
159+
final int partition = (int) sortedRecords.keyPrefix;
162160
assert (partition >= currentPartition);
163161
if (partition != currentPartition) {
164162
// Switch to the new partition
@@ -172,13 +170,13 @@ private long[] writeSortedRecordsToFile(
172170
}
173171

174172
PlatformDependent.copyMemory(
175-
recordPointer.baseObject,
176-
recordPointer.baseOffset + 4,
173+
sortedRecords.baseObject,
174+
sortedRecords.baseOffset + 4,
177175
arr,
178176
PlatformDependent.BYTE_ARRAY_OFFSET,
179-
recordPointer.recordLength);
177+
sortedRecords.recordLength);
180178
assert (writer != null); // To suppress an IntelliJ warning
181-
writer.write(arr, 0, recordPointer.recordLength);
179+
writer.write(arr, 0, sortedRecords.recordLength);
182180
// TODO: add a test that detects whether we leave this call out:
183181
writer.recordWritten();
184182
}
Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,31 @@
1+
/*
2+
* Licensed to the Apache Software Foundation (ASF) under one or more
3+
* contributor license agreements. See the NOTICE file distributed with
4+
* this work for additional information regarding copyright ownership.
5+
* The ASF licenses this file to You under the Apache License, Version 2.0
6+
* (the "License"); you may not use this file except in compliance with
7+
* the License. You may obtain a copy of the License at
8+
*
9+
* http://www.apache.org/licenses/LICENSE-2.0
10+
*
11+
* Unless required by applicable law or agreed to in writing, software
12+
* distributed under the License is distributed on an "AS IS" BASIS,
13+
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
* See the License for the specific language governing permissions and
15+
* limitations under the License.
16+
*/
17+
18+
package org.apache.spark.unsafe.sort;
19+
20+
/**
 * Abstract cursor over the records produced by the external sorter.
 *
 * Rather than returning one object per record, the current record is exposed through the
 * public mutable fields below. The caller visible in this change (UnsafeShuffleWriter's
 * writeSortedRecordsToFile) loops on {@link #hasNext()}, calls {@link #loadNext()}, and then
 * reads the fields.
 */
public abstract class ExternalSorterIterator {
21+
22+
// Base object of the current record's memory location; the shuffle writer passes it, together
// with baseOffset, as the source of PlatformDependent.copyMemory.
public Object baseObject;
23+
// Offset of the current record relative to baseObject. The shuffle writer copies from
// baseOffset + 4 (presumably skipping a 4-byte length header -- confirm against the sorter).
public long baseOffset;
24+
// Number of bytes the shuffle writer copies and writes out for the current record.
public int recordLength;
25+
// Sort-key prefix of the current record; the shuffle writer casts it to int and uses it as
// the partition id.
public long keyPrefix;
26+
27+
// Presumably reports whether another record is available; implementations are not visible here.
public abstract boolean hasNext();
28+
29+
// Presumably advances to the next record and populates the public fields above; confirm
// against the concrete implementations.
public abstract void loadNext();
30+
31+
}

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeExternalSorter.java

Lines changed: 67 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,15 @@
1919

2020
import com.google.common.annotations.VisibleForTesting;
2121
import org.apache.spark.SparkConf;
22+
import org.apache.spark.TaskContext;
2223
import org.apache.spark.executor.ShuffleWriteMetrics;
2324
import org.apache.spark.shuffle.ShuffleMemoryManager;
2425
import org.apache.spark.storage.BlockManager;
2526
import org.apache.spark.unsafe.PlatformDependent;
2627
import org.apache.spark.unsafe.memory.MemoryBlock;
2728
import org.apache.spark.unsafe.memory.TaskMemoryManager;
29+
import org.slf4j.Logger;
30+
import org.slf4j.LoggerFactory;
2831

2932
import java.io.IOException;
3033
import java.util.Iterator;
@@ -37,16 +40,20 @@
3740
*/
3841
public final class UnsafeExternalSorter {
3942

43+
private final Logger logger = LoggerFactory.getLogger(UnsafeExternalSorter.class);
44+
4045
private static final int PAGE_SIZE = 1024 * 1024; // TODO: tune this
4146

4247
private final PrefixComparator prefixComparator;
4348
private final RecordComparator recordComparator;
4449
private final int initialSize;
50+
private int numSpills = 0;
4551
private UnsafeSorter sorter;
4652

4753
private final TaskMemoryManager memoryManager;
4854
private final ShuffleMemoryManager shuffleMemoryManager;
4955
private final BlockManager blockManager;
56+
private final TaskContext taskContext;
5057
private final LinkedList<MemoryBlock> allocatedPages = new LinkedList<MemoryBlock>();
5158
private final boolean spillingEnabled;
5259
private final int fileBufferSize;
@@ -63,13 +70,15 @@ public UnsafeExternalSorter(
6370
TaskMemoryManager memoryManager,
6471
ShuffleMemoryManager shuffleMemoryManager,
6572
BlockManager blockManager,
73+
TaskContext taskContext,
6674
RecordComparator recordComparator,
6775
PrefixComparator prefixComparator,
6876
int initialSize,
69-
SparkConf conf) {
77+
SparkConf conf) throws IOException {
7078
this.memoryManager = memoryManager;
7179
this.shuffleMemoryManager = shuffleMemoryManager;
7280
this.blockManager = blockManager;
81+
this.taskContext = taskContext;
7382
this.recordComparator = recordComparator;
7483
this.prefixComparator = prefixComparator;
7584
this.initialSize = initialSize;
@@ -81,9 +90,19 @@ public UnsafeExternalSorter(
8190

8291
// TODO: metrics tracking + integration with shuffle write metrics
8392

84-
private void openSorter() {
93+
private void openSorter() throws IOException {
8594
this.writeMetrics = new ShuffleWriteMetrics();
8695
// TODO: connect write metrics to task metrics?
96+
// TODO: move this sizing calculation logic into a static method of sorter:
97+
final long memoryRequested = initialSize * 8L * 2;
98+
if (spillingEnabled) {
99+
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryRequested);
100+
if (memoryAcquired != memoryRequested) {
101+
shuffleMemoryManager.release(memoryAcquired);
102+
throw new IOException("Could not acquire memory!");
103+
}
104+
}
105+
87106
this.sorter = new UnsafeSorter(memoryManager, recordComparator, prefixComparator, initialSize);
88107
}
89108

@@ -101,23 +120,52 @@ public void spill() throws IOException {
101120
spillWriter.write(baseObject, baseOffset, recordLength, recordPointer.keyPrefix);
102121
}
103122
spillWriter.close();
123+
final long sorterMemoryUsage = sorter.getMemoryUsage();
104124
sorter = null;
105-
freeMemory();
125+
shuffleMemoryManager.release(sorterMemoryUsage);
126+
final long spillSize = freeMemory();
127+
taskContext.taskMetrics().incMemoryBytesSpilled(spillSize);
128+
taskContext.taskMetrics().incDiskBytesSpilled(spillWriter.numberOfSpilledBytes());
129+
numSpills++;
130+
final long threadId = Thread.currentThread().getId();
131+
// TODO: messy; log _before_ spill
132+
logger.info("Thread " + threadId + " spilling in-memory map of " +
133+
org.apache.spark.util.Utils.bytesToString(spillSize) + " to disk (" +
134+
(numSpills + ((numSpills > 1) ? " times" : " time")) + " so far)");
106135
openSorter();
107136
}
108137

109-
private void freeMemory() {
138+
private long freeMemory() {
139+
long memoryFreed = 0;
110140
final Iterator<MemoryBlock> iter = allocatedPages.iterator();
111141
while (iter.hasNext()) {
112142
memoryManager.freePage(iter.next());
113143
shuffleMemoryManager.release(PAGE_SIZE);
144+
memoryFreed += PAGE_SIZE;
114145
iter.remove();
115146
}
116147
currentPage = null;
117148
currentPagePosition = -1;
149+
return memoryFreed;
118150
}
119151

120152
private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
153+
// TODO: merge these steps to first calculate total memory requirements for this insert,
154+
// then try to acquire; no point in acquiring sort buffer only to spill due to no space in the
155+
// data page.
156+
if (!sorter.hasSpaceForAnotherRecord() && spillingEnabled) {
157+
final long oldSortBufferMemoryUsage = sorter.getMemoryUsage();
158+
final long memoryToGrowSortBuffer = oldSortBufferMemoryUsage * 2;
159+
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(memoryToGrowSortBuffer);
160+
if (memoryAcquired < memoryToGrowSortBuffer) {
161+
shuffleMemoryManager.release(memoryAcquired);
162+
spill();
163+
} else {
164+
sorter.expandSortBuffer();
165+
shuffleMemoryManager.release(oldSortBufferMemoryUsage);
166+
}
167+
}
168+
121169
final long spaceInCurrentPage;
122170
if (currentPage != null) {
123171
spaceInCurrentPage = PAGE_SIZE - (currentPagePosition - currentPage.getBaseOffset());
@@ -129,12 +177,22 @@ private void ensureSpaceInDataPage(int requiredSpace) throws Exception {
129177
throw new Exception("Required space " + requiredSpace + " is greater than page size (" +
130178
PAGE_SIZE + ")");
131179
} else if (requiredSpace > spaceInCurrentPage) {
132-
if (spillingEnabled && shuffleMemoryManager.tryToAcquire(PAGE_SIZE) < PAGE_SIZE) {
133-
spill();
180+
if (spillingEnabled) {
181+
final long memoryAcquired = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
182+
if (memoryAcquired < PAGE_SIZE) {
183+
shuffleMemoryManager.release(memoryAcquired);
184+
spill();
185+
final long memoryAcquiredAfterSpill = shuffleMemoryManager.tryToAcquire(PAGE_SIZE);
186+
if (memoryAcquiredAfterSpill != PAGE_SIZE) {
187+
shuffleMemoryManager.release(memoryAcquiredAfterSpill);
188+
throw new Exception("Can't allocate memory!");
189+
}
190+
}
134191
}
135192
currentPage = memoryManager.allocatePage(PAGE_SIZE);
136193
currentPagePosition = currentPage.getBaseOffset();
137194
allocatedPages.add(currentPage);
195+
logger.info("Acquired new page! " + allocatedPages.size() * PAGE_SIZE);
138196
}
139197
}
140198

@@ -162,9 +220,9 @@ public void insertRecord(
162220
sorter.insertRecord(recordAddress, prefix);
163221
}
164222

165-
public Iterator<UnsafeExternalSortSpillMerger.RecordAddressAndKeyPrefix> getSortedIterator() throws IOException {
166-
final UnsafeExternalSortSpillMerger spillMerger =
167-
new UnsafeExternalSortSpillMerger(recordComparator, prefixComparator);
223+
public ExternalSorterIterator getSortedIterator() throws IOException {
224+
final UnsafeSorterSpillMerger spillMerger =
225+
new UnsafeSorterSpillMerger(recordComparator, prefixComparator);
168226
for (UnsafeSorterSpillWriter spillWriter : spillWriters) {
169227
spillMerger.addSpill(spillWriter.getReader(blockManager));
170228
}

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorter.java

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -86,10 +86,9 @@ public static abstract class PrefixComparator {
8686
*/
8787
private int sortBufferInsertPosition = 0;
8888

89-
private void expandSortBuffer(int newSize) {
90-
assert (newSize > sortBuffer.length);
89+
public void expandSortBuffer() {
9190
final long[] oldBuffer = sortBuffer;
92-
sortBuffer = new long[newSize];
91+
sortBuffer = new long[oldBuffer.length * 2];
9392
System.arraycopy(oldBuffer, 0, sortBuffer, 0, oldBuffer.length);
9493
}
9594

@@ -122,14 +121,22 @@ public int compare(RecordPointerAndKeyPrefix left, RecordPointerAndKeyPrefix rig
122121
};
123122
}
124123

124+
/**
 * Returns the size of the sort buffer in bytes: the number of {@code long} slots in
 * {@code sortBuffer} multiplied by 8. Memory used by the records themselves is not included.
 */
public long getMemoryUsage() {
125+
return sortBuffer.length * 8L;
126+
}
127+
128+
/**
 * Reports whether another record can be inserted without first expanding the sort buffer.
 *
 * The check reserves two slots past the current insert position (the sort is invoked with
 * {@code sortBufferInsertPosition / 2} records, i.e. two slots per record).
 * NOTE(review): because the comparison is strict, this returns false when exactly two free
 * slots remain, so the buffer is grown (or a spill is triggered) one record early -- confirm
 * whether that is intended.
 */
public boolean hasSpaceForAnotherRecord() {
129+
return sortBufferInsertPosition + 2 < sortBuffer.length;
130+
}
131+
125132
/**
126133
* Insert a record into the sort buffer.
127134
*
128135
* @param objectAddress pointer to a record in a data page, encoded by {@link TaskMemoryManager}.
129136
*/
130137
public void insertRecord(long objectAddress, long keyPrefix) {
131-
if (sortBufferInsertPosition + 2 == sortBuffer.length) {
132-
expandSortBuffer(sortBuffer.length * 2);
138+
if (!hasSpaceForAnotherRecord()) {
139+
expandSortBuffer();
133140
}
134141
sortBuffer[sortBufferInsertPosition] = objectAddress;
135142
sortBufferInsertPosition++;
@@ -167,10 +174,10 @@ public void remove() {
167174
};
168175
}
169176

170-
public UnsafeExternalSortSpillMerger.MergeableIterator getMergeableIterator() {
177+
public UnsafeSorterSpillMerger.MergeableIterator getMergeableIterator() {
171178
sorter.sort(sortBuffer, 0, sortBufferInsertPosition / 2, sortComparator);
172-
UnsafeExternalSortSpillMerger.MergeableIterator iter =
173-
new UnsafeExternalSortSpillMerger.MergeableIterator() {
179+
UnsafeSorterSpillMerger.MergeableIterator iter =
180+
new UnsafeSorterSpillMerger.MergeableIterator() {
174181

175182
private int position = 0;
176183
private Object baseObject;
Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323

2424
import static org.apache.spark.unsafe.sort.UnsafeSorter.*;
2525

26-
public final class UnsafeExternalSortSpillMerger {
26+
public final class UnsafeSorterSpillMerger {
2727

2828
private final PriorityQueue<MergeableIterator> priorityQueue;
2929

@@ -46,9 +46,9 @@ public static final class RecordAddressAndKeyPrefix {
4646
public long keyPrefix;
4747
}
4848

49-
public UnsafeExternalSortSpillMerger(
50-
final RecordComparator recordComparator,
51-
final UnsafeSorter.PrefixComparator prefixComparator) {
49+
public UnsafeSorterSpillMerger(
50+
final RecordComparator recordComparator,
51+
final UnsafeSorter.PrefixComparator prefixComparator) {
5252
final Comparator<MergeableIterator> comparator = new Comparator<MergeableIterator>() {
5353

5454
@Override

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillReader.java

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424

2525
import java.io.*;
2626

27-
public final class UnsafeSorterSpillReader extends UnsafeExternalSortSpillMerger.MergeableIterator {
27+
final class UnsafeSorterSpillReader extends UnsafeSorterSpillMerger.MergeableIterator {
2828

2929
private final File file;
3030
private InputStream in;

core/src/main/java/org/apache/spark/unsafe/sort/UnsafeSorterSpillWriter.java

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@
3333
import java.io.*;
3434
import java.nio.ByteBuffer;
3535

36-
public final class UnsafeSorterSpillWriter {
36+
final class UnsafeSorterSpillWriter {
3737

3838
private static final int SER_BUFFER_SIZE = 1024 * 1024; // TODO: tune this
3939
public static final int EOF_MARKER = -1;
@@ -122,6 +122,10 @@ public void close() throws IOException {
122122
arr = null;
123123
}
124124

125+
/**
 * Returns {@code file.length()} for this writer's spill file, i.e. its current size in bytes
 * (assuming {@code file} is a java.io.File -- its declaration is not visible here).
 * UnsafeExternalSorter.spill() calls this after close() to update the disk-bytes-spilled metric.
 */
public long numberOfSpilledBytes() {
126+
return file.length();
127+
}
128+
125129
public UnsafeSorterSpillReader getReader(BlockManager blockManager) throws IOException {
126130
return new UnsafeSorterSpillReader(blockManager, file, blockId);
127131
}

0 commit comments

Comments
 (0)