Skip to content
100 changes: 100 additions & 0 deletions core/src/main/java/org/apache/spark/memory/MemoryConsumer.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.spark.memory;


import java.io.IOException;

import org.apache.spark.unsafe.memory.MemoryBlock;


/**
* A memory consumer of TaskMemoryManager, which supports spilling.
*/
public class MemoryConsumer {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the idea that each operator will have its own subclass of MemoryConsumer which implements spill()?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, we could make it abstract


/**
 * The task memory manager through which this consumer acquires and releases execution
 * memory and pages. Set once by the constructor and never reassigned.
 */
private final TaskMemoryManager memoryManager;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do you mind renaming this variable to taskMemoryManager in order to avoid any ambiguity?

/**
 * Page size in bytes; allocatePage() requests at least this many bytes per page.
 * Set once by the constructor and never reassigned.
 */
private final long pageSize;

/**
 * Creates a consumer bound to the given task memory manager with an explicit page size.
 *
 * @param memoryManager the manager used for every memory and page request of this consumer
 * @param pageSize the page size, in bytes, used as the minimum size when allocating pages
 */
protected MemoryConsumer(TaskMemoryManager memoryManager, long pageSize) {
  this.pageSize = pageSize;
  this.memoryManager = memoryManager;
}

/**
* Creates a consumer bound to the given task memory manager, using the page size that the
* manager reports through pageSizeBytes().
*
* @param memoryManager the manager used for every memory and page request of this consumer
*/
protected MemoryConsumer(TaskMemoryManager memoryManager) {
this(memoryManager, memoryManager.pageSizeBytes());
}

/**
* Spill some data to disk to release memory. TaskMemoryManager calls this when there is
* not enough memory for the task.
*
* This base implementation spills nothing and always returns 0.
*
* @param size the amount of memory, in bytes, that should be released
* @return the amount of memory actually released, in bytes
* @throws IOException never thrown by this base implementation
*/
public long spill(long size) throws IOException {
return 0L;
}

/**
 * Acquire {@code size} bytes of memory for this consumer.
 *
 * The request is all-or-nothing: if the task memory manager grants fewer than {@code size}
 * bytes, the partial grant is handed back before the exception is thrown, so a failed call
 * does not leave any extra memory charged to this consumer.
 *
 * @param size the number of bytes required
 * @throws IOException if fewer than {@code size} bytes could be acquired
 */
protected void acquireMemory(long size) throws IOException {
  long got = memoryManager.acquireExecutionMemory(size, this);
  if (got < size) {
    // Give back the partial grant; otherwise it would stay accounted to this consumer even
    // though the caller is told the acquisition failed (allocatePage() does the same for an
    // undersized page).
    if (got > 0) {
      memoryManager.releaseExecutionMemory(got, this);
    }
    throw new IOException("Could not acquire " + size + " bytes of memory, got " + got);
  }
}

/**
* Release `size` bytes of memory held by this consumer back to the task memory manager.
*
* @param size the number of bytes to release
*/
protected void releaseMemory(long size) {
memoryManager.releaseExecutionMemory(size, this);
}

/**
 * Allocate a memory block of at least {@code required} bytes on behalf of this consumer.
 *
 * The size requested from the task memory manager is the larger of this consumer's page size
 * and {@code required}. A block that comes back smaller than {@code required} is freed again
 * before the failure is reported.
 *
 * @param required the minimum number of bytes the returned block must hold
 * @return a block whose size is at least {@code required}
 * @throws IOException if no block, or only a smaller block, could be obtained
 */
protected MemoryBlock allocatePage(long required) throws IOException {
  final long requestSize = Math.max(pageSize, required);
  final MemoryBlock block = memoryManager.allocatePage(requestSize, this);

  // Success path first: a non-null block that is large enough.
  if (block != null && block.size() >= required) {
    return block;
  }

  // An undersized block is of no use to the caller, so hand it back before failing.
  if (block != null) {
    freePage(block);
  }
  throw new IOException("Unable to acquire " + required + " bytes of memory");
}

/**
* Free a memory block by returning it to the task memory manager on behalf of this consumer.
*
* @param page the block to free
*/
protected void freePage(MemoryBlock page) {
memoryManager.freePage(page, this);
}
}
112 changes: 91 additions & 21 deletions core/src/main/java/org/apache/spark/memory/TaskMemoryManager.java
Original file line number Diff line number Diff line change
Expand Up @@ -17,13 +17,18 @@

package org.apache.spark.memory;

import java.util.*;
import java.io.IOException;
import java.util.BitSet;
import java.util.HashMap;

import com.google.common.annotations.VisibleForTesting;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import org.apache.spark.SparkException;
import org.apache.spark.unsafe.Platform;
import org.apache.spark.unsafe.memory.MemoryBlock;
import org.apache.spark.util.Utils;

/**
* Manages the memory allocated by an individual task.
Expand Down Expand Up @@ -100,28 +105,96 @@ public class TaskMemoryManager {
*/
private final boolean inHeap;

/**
 * The amount of memory, in bytes, currently granted to each consumer. An entry is added or
 * increased by acquireExecutionMemory() and decreased or removed by releaseExecutionMemory().
 * The map is created once in the constructor and never replaced.
 */
private final HashMap<MemoryConsumer, Long> consumers;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could be final.


/**
 * Construct a new TaskMemoryManager.
 *
 * @param memoryManager the manager that execution-memory requests are forwarded to; it also
 *                      decides whether Tungsten memory is allocated in-heap
 * @param taskAttemptId the id of the task attempt this manager tracks memory for
 */
public TaskMemoryManager(MemoryManager memoryManager, long taskAttemptId) {
  this.memoryManager = memoryManager;
  this.taskAttemptId = taskAttemptId;
  this.inHeap = memoryManager.tungstenMemoryIsAllocatedInHeap();
  // No consumer has been granted memory yet.
  this.consumers = new HashMap<>();
}

/**
* Acquire N bytes of memory for execution, evicting cached blocks if necessary.
* Acquire N bytes of memory for a consumer. If there is not enough memory, it will call
* spill() on consumers to release more memory.
*
* @return number of bytes successfully granted (<= N).
*/
public long acquireExecutionMemory(long size) {
return memoryManager.acquireExecutionMemory(size, taskAttemptId);
public long acquireExecutionMemory(long size, MemoryConsumer consumer) throws IOException {
synchronized (this) {
long got = memoryManager.acquireExecutionMemory(size, taskAttemptId);

// call spill() on itself to release some memory
if (got < size && consumer != null) {
consumer.spill(size - got);
got += memoryManager.acquireExecutionMemory(size - got, taskAttemptId);
}

if (got < size) {
long needed = size - got;
// call spill() on other consumers to release memory
for (MemoryConsumer c: consumers.keySet()) {
if (c != null && c != consumer) {
needed -= c.spill(size - got);
if (needed < 0) {
break;
}
}
}
got += memoryManager.acquireExecutionMemory(size - got, taskAttemptId);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what if consumer is the only one in consumers ?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Then this call will be a no-op and could be avoided.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

My point was that got may not be equal to size coming out of the loop.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This function does not guarantee that got will be equal to size, even after calling spill().

}

long old = 0L;
if (consumers.containsKey(consumer)) {
old = consumers.get(consumer);
}
consumers.put(consumer, got + old);

return got;
}
}

/**
* Release N bytes of execution memory.
* Release N bytes of execution memory for a MemoryConsumer.
*/
public void releaseExecutionMemory(long size) {
memoryManager.releaseExecutionMemory(size, taskAttemptId);
public void releaseExecutionMemory(long size, MemoryConsumer consumer) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add an assert to make sure size >= 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

if (size == 0) {
return;
}
synchronized (this) {
if (consumers.containsKey(consumer)) {
long old = consumers.get(consumer);
if (old > size) {
consumers.put(consumer, old - size);
} else {
if (old < size) {
if (Utils.isTesting()) {
Platform.throwException(
new SparkException("Release more memory " + size + "than acquired " + old + " for "
+ consumer));
} else {
logger.warn("Release more memory " + size + " than acquired " + old + "for "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Space before for

+ consumer);
}
}
consumers.remove(consumer);
}
} else {
if (Utils.isTesting()) {
Platform.throwException(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe add new method to Utils, accepting message String, which covers lines 191 to 196

new SparkException("Release memory " + size + " for not existed " + consumer));
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

not existed -> non-existent

} else {
logger.warn("Release memory " + size + " for not existed " + consumer);
}
}
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the else branch here supposed to be an error case?

memoryManager.releaseExecutionMemory(size, taskAttemptId);
}
}

public long pageSizeBytes() {
Expand All @@ -134,12 +207,17 @@ public long pageSizeBytes() {
*
* Returns `null` if there was not enough memory to allocate the page.
*/
public MemoryBlock allocatePage(long size) {
public MemoryBlock allocatePage(long size, MemoryConsumer consumer) throws IOException {
if (size > MAXIMUM_PAGE_SIZE_BYTES) {
throw new IllegalArgumentException(
"Cannot allocate a page with more than " + MAXIMUM_PAGE_SIZE_BYTES + " bytes");
}

long acquired = acquireExecutionMemory(size, consumer);
if (acquired <= 0) {
return null;
}

final int pageNumber;
synchronized (this) {
pageNumber = allocatedPages.nextClearBit(0);
Expand All @@ -149,14 +227,6 @@ public MemoryBlock allocatePage(long size) {
}
allocatedPages.set(pageNumber);
}
final long acquiredExecutionMemory = acquireExecutionMemory(size);
if (acquiredExecutionMemory != size) {
releaseExecutionMemory(acquiredExecutionMemory);
synchronized (this) {
allocatedPages.clear(pageNumber);
}
return null;
}
final MemoryBlock page = memoryManager.tungstenMemoryAllocator().allocate(size);
page.pageNumber = pageNumber;
pageTable[pageNumber] = page;
Expand All @@ -167,9 +237,9 @@ public MemoryBlock allocatePage(long size) {
}

/**
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage(long)}.
* Free a block of memory allocated via {@link TaskMemoryManager#allocatePage}.
*/
public void freePage(MemoryBlock page) {
public void freePage(MemoryBlock page, MemoryConsumer consumer) {
assert (page.pageNumber != -1) :
"Called freePage() on memory that wasn't allocated with allocatePage()";
assert(allocatedPages.get(page.pageNumber));
Expand All @@ -182,14 +252,14 @@ public void freePage(MemoryBlock page) {
}
long pageSize = page.size();
memoryManager.tungstenMemoryAllocator().free(page);
releaseExecutionMemory(pageSize);
releaseExecutionMemory(pageSize, consumer);
}

/**
* Given a memory page and offset within that page, encode this address into a 64-bit long.
* This address will remain valid as long as the corresponding page has not been freed.
*
* @param page a data page allocated by {@link TaskMemoryManager#allocatePage(long)}/
* @param page a data page allocated by {@link TaskMemoryManager#allocatePage}/
* @param offsetInPage an offset in this page which incorporates the base offset. In other words,
* this should be the value that you would pass as the base offset into an
* UNSAFE call (e.g. page.baseOffset() + something).
Expand Down Expand Up @@ -265,7 +335,7 @@ public long cleanUpAllAllocatedMemory() {
for (MemoryBlock page : pageTable) {
if (page != null) {
freedBytes += page.size();
freePage(page);
freePage(page, null);
}
}

Expand Down
Loading