Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 15 additions & 11 deletions velox/exec/PartitionedOutputBufferManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -483,6 +483,14 @@ PartitionedOutputBufferManager::getBuffer(const std::string& taskId) {
});
}

// Looks up the buffer registered under 'taskId' inside the critical section
// provided by buffers_.withLock(). Returns nullptr when the map has no entry
// for 'taskId'; otherwise returns a copy of the stored shared pointer.
std::shared_ptr<PartitionedOutputBuffer>
PartitionedOutputBufferManager::getBufferIfExists(const std::string& taskId) {
  return buffers_.withLock(
      [&taskId](auto& taskBuffers) -> std::shared_ptr<PartitionedOutputBuffer> {
        const auto entry = taskBuffers.find(taskId);
        if (entry == taskBuffers.end()) {
          return nullptr;
        }
        return entry->second;
      });
}

// Returns the number of entries currently stored in 'buffers_'. The size is
// read through the handle returned by buffers_.lock().
uint64_t PartitionedOutputBufferManager::numBuffers() const {
  const auto lockedBuffers = buffers_.lock();
  return lockedBuffers->size();
}
Expand Down Expand Up @@ -525,26 +533,22 @@ void PartitionedOutputBufferManager::acknowledge(
void PartitionedOutputBufferManager::deleteResults(
const std::string& taskId,
int destination) {
auto buffer = buffers_.withLock(
[&](auto& buffers) -> std::shared_ptr<PartitionedOutputBuffer> {
auto it = buffers.find(taskId);
if (it == buffers.end()) {
return nullptr;
}
return it->second;
});
if (buffer) {
if (auto buffer = getBufferIfExists(taskId)) {
buffer->deleteResults(destination);
}
}

void PartitionedOutputBufferManager::getData(
bool PartitionedOutputBufferManager::getData(
const std::string& taskId,
int destination,
uint64_t maxBytes,
int64_t sequence,
DataAvailableCallback notify) {
getBuffer(taskId)->getData(destination, maxBytes, sequence, notify);
if (auto buffer = getBufferIfExists(taskId)) {
buffer->getData(destination, maxBytes, sequence, notify);
return true;
}
return false;
}

void PartitionedOutputBufferManager::initializeTask(
Expand Down
18 changes: 13 additions & 5 deletions velox/exec/PartitionedOutputBufferManager.h
Original file line number Diff line number Diff line change
Expand Up @@ -225,17 +225,19 @@ class PartitionedOutputBufferManager {
void deleteResults(const std::string& taskId, int destination);

// Adds up to 'maxBytes' bytes worth of data for 'destination' from
// 'taskId'. The sequence number of the data must be >=
// 'sequence'. If there is no data, 'notify' will be registered and
// called when there is data or the source is at end. Existing data
// with a sequence number < sequence is deleted. The caller is
// 'taskId'. The sequence number of the data must be >= 'sequence'.
// Returns false if there is no buffer associated with the given taskId;
// otherwise returns true. If there is no data, 'notify' will be registered
// and called when there is data or the source is at end.
// Existing data with a sequence number < sequence is deleted. The caller is
// expected to increment the sequence number between calls by the
// number of items received. In this way the next call implicitly
// acknowledges receipt of the results from the previous. The
// acknowledge method is offered for an early ack, so that the
// producer can continue before the consumer is done processing the
// received data.
void getData(
bool getData(
const std::string& taskId,
int destination,
uint64_t maxBytes,
Expand Down Expand Up @@ -264,8 +266,14 @@ class PartitionedOutputBufferManager {

private:
// Retrieves the set of buffers for a query.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's clarify that this method throws if the buffer for the specified task does not exist.

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I still think we should clarify that this method throws when task is not found.

// Retrieves the set of buffers for the specified task. Throws a buffers-not-found exception if not found.

// Throws an exception if buffer doesn't exist.
std::shared_ptr<PartitionedOutputBuffer> getBuffer(const std::string& taskId);

// Retrieves the set of buffers for a query if it exists.
// Returns nullptr if the task is not found.
std::shared_ptr<PartitionedOutputBuffer> getBufferIfExists(
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is a nice method. Please add an empty line before this method and document it. Let's use this method in PartitionedOutputBufferManager::deleteResults as well.

const std::string& taskId);

folly::Synchronized<
std::unordered_map<std::string, std::shared_ptr<PartitionedOutputBuffer>>,
std::mutex>
Expand Down
30 changes: 22 additions & 8 deletions velox/exec/tests/PartitionedOutputBufferManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ class PartitionedOutputBufferManagerTest : public testing::Test {
uint64_t maxBytes = 1024,
int expectedGroups = 1) {
bool receivedData = false;
bufferManager_->getData(
ASSERT_TRUE(bufferManager_->getData(
taskId,
destination,
maxBytes,
Expand All @@ -124,7 +124,7 @@ class PartitionedOutputBufferManagerTest : public testing::Test {
}
EXPECT_EQ(inSequence, sequence) << "for destination " << destination;
receivedData = true;
});
}));
EXPECT_TRUE(receivedData) << "for destination " << destination;
}

Expand Down Expand Up @@ -163,12 +163,12 @@ class PartitionedOutputBufferManagerTest : public testing::Test {
void
fetchEndMarker(const std::string& taskId, int destination, int64_t sequence) {
bool receivedData = false;
bufferManager_->getData(
ASSERT_TRUE(bufferManager_->getData(
taskId,
destination,
std::numeric_limits<uint64_t>::max(),
sequence,
receiveEndMarker(destination, sequence, receivedData));
receiveEndMarker(destination, sequence, receivedData)));
EXPECT_TRUE(receivedData) << "for destination " << destination;
bufferManager_->deleteResults(taskId, destination);
}
Expand All @@ -183,12 +183,12 @@ class PartitionedOutputBufferManagerTest : public testing::Test {
int64_t sequence,
bool& receivedEndMarker) {
receivedEndMarker = false;
bufferManager_->getData(
ASSERT_TRUE(bufferManager_->getData(
taskId,
destination,
std::numeric_limits<uint64_t>::max(),
sequence,
receiveEndMarker(destination, 1, receivedEndMarker));
receiveEndMarker(destination, 1, receivedEndMarker)));
EXPECT_FALSE(receivedEndMarker) << "for destination " << destination;
}

Expand Down Expand Up @@ -219,12 +219,12 @@ class PartitionedOutputBufferManagerTest : public testing::Test {
int expectedGroups,
bool& receivedData) {
receivedData = false;
bufferManager_->getData(
ASSERT_TRUE(bufferManager_->getData(
taskId,
destination,
1024,
sequence,
receiveData(destination, sequence, expectedGroups, receivedData));
receiveData(destination, sequence, expectedGroups, receivedData)));
EXPECT_FALSE(receivedData) << "for destination " << destination;
}

Expand Down Expand Up @@ -420,3 +420,17 @@ TEST_F(PartitionedOutputBufferManagerTest, serializedPage) {
EXPECT_EQ(mappedMemory->allocateBytesStats().totalSmall, 0);
}
}

TEST_F(PartitionedOutputBufferManagerTest, getDataOnFailedTask) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's update existing calls to getData to assert that they return true.

// Fetching data on a task which was either never initialized in the buffer
// manager or was removed by a parallel thread must return false. The `notify`
// callback must not be registered.
ASSERT_FALSE(bufferManager_->getData(
"test.0.1",
1,
10,
1,
[](std::vector<std::unique_ptr<folly::IOBuf>> pages, int64_t sequence) {
VELOX_UNREACHABLE();
}));
}