diff --git a/velox/exec/PartitionedOutputBufferManager.cpp b/velox/exec/PartitionedOutputBufferManager.cpp index 43138e650c7..783cfa6edc7 100644 --- a/velox/exec/PartitionedOutputBufferManager.cpp +++ b/velox/exec/PartitionedOutputBufferManager.cpp @@ -483,6 +483,14 @@ PartitionedOutputBufferManager::getBuffer(const std::string& taskId) { }); } +std::shared_ptr +PartitionedOutputBufferManager::getBufferIfExists(const std::string& taskId) { + return buffers_.withLock([&](auto& buffers) { + auto it = buffers.find(taskId); + return it == buffers.end() ? nullptr : it->second; + }); +} + uint64_t PartitionedOutputBufferManager::numBuffers() const { return buffers_.lock()->size(); } @@ -525,26 +533,22 @@ void PartitionedOutputBufferManager::acknowledge( void PartitionedOutputBufferManager::deleteResults( const std::string& taskId, int destination) { - auto buffer = buffers_.withLock( - [&](auto& buffers) -> std::shared_ptr { - auto it = buffers.find(taskId); - if (it == buffers.end()) { - return nullptr; - } - return it->second; - }); - if (buffer) { + if (auto buffer = getBufferIfExists(taskId)) { buffer->deleteResults(destination); } } -void PartitionedOutputBufferManager::getData( +bool PartitionedOutputBufferManager::getData( const std::string& taskId, int destination, uint64_t maxBytes, int64_t sequence, DataAvailableCallback notify) { - getBuffer(taskId)->getData(destination, maxBytes, sequence, notify); + if (auto buffer = getBufferIfExists(taskId)) { + buffer->getData(destination, maxBytes, sequence, notify); + return true; + } + return false; } void PartitionedOutputBufferManager::initializeTask( diff --git a/velox/exec/PartitionedOutputBufferManager.h b/velox/exec/PartitionedOutputBufferManager.h index 8f7dfd188b3..bc3bfb3ee9b 100644 --- a/velox/exec/PartitionedOutputBufferManager.h +++ b/velox/exec/PartitionedOutputBufferManager.h @@ -225,17 +225,19 @@ class PartitionedOutputBufferManager { void deleteResults(const std::string& taskId, int destination); // 
Adds up to 'maxBytes' bytes worth of data for 'destination' from - // 'taskId'. The sequence number of the data must be >= - // 'sequence'. If there is no data, 'notify' will be registered and - // called when there is data or the source is at end. Existing data - // with a sequence number < sequence is deleted. The caller is + // 'taskId'. The sequence number of the data must be >= 'sequence'. + // If there is no buffer associated with the given taskId, returns false. + // Otherwise returns true. If there is no data, 'notify' will be + // registered and called when there is data or the source is at + // end. + // Existing data with a sequence number < sequence is deleted. The caller is // expected to increment the sequence number between calls by the // number of items received. In this way the next call implicitly // acknowledges receipt of the results from the previous. The // acknowledge method is offered for an early ack, so that the // producer can continue before the consumer is done processing the // received data. - void getData( + bool getData( const std::string& taskId, int destination, uint64_t maxBytes, @@ -264,8 +266,14 @@ class PartitionedOutputBufferManager { private: // Retrieves the set of buffers for a query. + // Throws an exception if the buffer doesn't exist. std::shared_ptr getBuffer(const std::string& taskId); + // Retrieves the set of buffers for a query if it exists. + // Returns nullptr if the task is not found. 
+ std::shared_ptr getBufferIfExists( + const std::string& taskId); + folly::Synchronized< std::unordered_map>, std::mutex> diff --git a/velox/exec/tests/PartitionedOutputBufferManagerTest.cpp b/velox/exec/tests/PartitionedOutputBufferManagerTest.cpp index a7fd810f704..ac2a1415fa7 100644 --- a/velox/exec/tests/PartitionedOutputBufferManagerTest.cpp +++ b/velox/exec/tests/PartitionedOutputBufferManagerTest.cpp @@ -108,7 +108,7 @@ class PartitionedOutputBufferManagerTest : public testing::Test { uint64_t maxBytes = 1024, int expectedGroups = 1) { bool receivedData = false; - bufferManager_->getData( + ASSERT_TRUE(bufferManager_->getData( taskId, destination, maxBytes, @@ -124,7 +124,7 @@ class PartitionedOutputBufferManagerTest : public testing::Test { } EXPECT_EQ(inSequence, sequence) << "for destination " << destination; receivedData = true; - }); + })); EXPECT_TRUE(receivedData) << "for destination " << destination; } @@ -163,12 +163,12 @@ class PartitionedOutputBufferManagerTest : public testing::Test { void fetchEndMarker(const std::string& taskId, int destination, int64_t sequence) { bool receivedData = false; - bufferManager_->getData( + ASSERT_TRUE(bufferManager_->getData( taskId, destination, std::numeric_limits::max(), sequence, - receiveEndMarker(destination, sequence, receivedData)); + receiveEndMarker(destination, sequence, receivedData))); EXPECT_TRUE(receivedData) << "for destination " << destination; bufferManager_->deleteResults(taskId, destination); } @@ -183,12 +183,12 @@ class PartitionedOutputBufferManagerTest : public testing::Test { int64_t sequence, bool& receivedEndMarker) { receivedEndMarker = false; - bufferManager_->getData( + ASSERT_TRUE(bufferManager_->getData( taskId, destination, std::numeric_limits::max(), sequence, - receiveEndMarker(destination, 1, receivedEndMarker)); + receiveEndMarker(destination, 1, receivedEndMarker))); EXPECT_FALSE(receivedEndMarker) << "for destination " << destination; } @@ -219,12 +219,12 @@ class 
PartitionedOutputBufferManagerTest : public testing::Test { int expectedGroups, bool& receivedData) { receivedData = false; - bufferManager_->getData( + ASSERT_TRUE(bufferManager_->getData( taskId, destination, 1024, sequence, - receiveData(destination, sequence, expectedGroups, receivedData)); + receiveData(destination, sequence, expectedGroups, receivedData))); EXPECT_FALSE(receivedData) << "for destination " << destination; } @@ -420,3 +420,17 @@ TEST_F(PartitionedOutputBufferManagerTest, serializedPage) { EXPECT_EQ(mappedMemory->allocateBytesStats().totalSmall, 0); } } + +TEST_F(PartitionedOutputBufferManagerTest, getDataOnFailedTask) { + // Fetching data on a task which was either never initialized in the buffer + // manager or was removed by a parallel thread must return false. The `notify` + // callback must not be registered. + ASSERT_FALSE(bufferManager_->getData( + "test.0.1", + 1, + 10, + 1, + [](std::vector> pages, int64_t sequence) { + VELOX_UNREACHABLE(); + })); +}