Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 22 additions & 20 deletions velox/exec/Exchange.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ std::shared_ptr<ExchangeSource> ExchangeSource::create(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool) {
memory::MemoryPool* pool) {
for (auto& factory : factories()) {
auto result = factory(taskId, destination, queue, pool);
if (result) {
Expand All @@ -83,7 +83,7 @@ class LocalExchangeSource : public ExchangeSource {
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool)
memory::MemoryPool* pool)
: ExchangeSource(taskId, destination, queue, pool) {}

bool shouldRequestLocked() override {
Expand Down Expand Up @@ -133,16 +133,22 @@ class LocalExchangeSource : public ExchangeSource {
}
int64_t ackSequence;
{
std::lock_guard<std::mutex> l(queue_->mutex());
requestPending_ = false;
for (auto& page : pages) {
queue_->enqueue(std::move(page));
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(queue_->mutex());
requestPending_ = false;
for (auto& page : pages) {
queue_->enqueueLocked(std::move(page), promises);
}
if (atEnd) {
queue_->enqueueLocked(nullptr, promises);
atEnd_ = true;
}
ackSequence = sequence_ = sequence + pages.size();
}
if (atEnd) {
queue_->enqueue(nullptr);
atEnd_ = true;
for (auto& promise : promises) {
promise.setValue();
}
ackSequence = sequence_ = sequence + pages.size();
}
// Outside of queue mutex.
if (atEnd_) {
Expand All @@ -166,7 +172,7 @@ std::unique_ptr<ExchangeSource> createLocalExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool) {
memory::MemoryPool* pool) {
if (strncmp(taskId.c_str(), "local://", 8) == 0) {
return std::make_unique<LocalExchangeSource>(
taskId, destination, std::move(queue), pool);
Expand Down Expand Up @@ -194,7 +200,7 @@ void ExchangeClient::addRemoteTaskId(const std::string& taskId) {
toClose = std::move(source);
} else {
sources_.push_back(source);
queue_->addSource();
queue_->addSourceLocked();
if (source->shouldRequestLocked()) {
toRequest = source;
}
Expand All @@ -210,7 +216,6 @@ void ExchangeClient::addRemoteTaskId(const std::string& taskId) {
}

void ExchangeClient::noMoreRemoteTasks() {
std::lock_guard<std::mutex> l(queue_->mutex());
queue_->noMoreSources();
}

Expand All @@ -229,10 +234,7 @@ void ExchangeClient::close() {
for (auto& source : sources) {
source->close();
}
{
std::lock_guard<std::mutex> l(queue_->mutex());
queue_->closeLocked();
}
queue_->close();
}

std::unique_ptr<SerializedPage> ExchangeClient::next(
Expand All @@ -243,7 +245,7 @@ std::unique_ptr<SerializedPage> ExchangeClient::next(
{
std::lock_guard<std::mutex> l(queue_->mutex());
*atEnd = false;
page = queue_->dequeue(atEnd, future);
page = queue_->dequeueLocked(atEnd, future);
if (*atEnd) {
return page;
}
Expand Down Expand Up @@ -278,7 +280,7 @@ std::string ExchangeClient::toString() {
return out.str();
}

bool Exchange::getSplits(ContinueFuture* FOLLY_NONNULL future) {
bool Exchange::getSplits(ContinueFuture* future) {
if (operatorCtx_->driverCtx()->driverId != 0) {
// When there are multiple pipelines, a single operator, the one from
// pipeline 0, is responsible for feeding splits into shared ExchangeClient.
Expand Down Expand Up @@ -313,7 +315,7 @@ bool Exchange::getSplits(ContinueFuture* FOLLY_NONNULL future) {
}
}

BlockingReason Exchange::isBlocked(ContinueFuture* FOLLY_NONNULL future) {
BlockingReason Exchange::isBlocked(ContinueFuture* future) {
if (currentPage_ || atEnd_) {
return BlockingReason::kNotBlocked;
}
Expand Down
115 changes: 75 additions & 40 deletions velox/exec/Exchange.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ class SerializedPage {
// TODO: consider to enforce setting memory pool if possible.
explicit SerializedPage(
std::unique_ptr<folly::IOBuf> iobuf,
memory::MemoryPool* FOLLY_NULLABLE pool = nullptr,
memory::MemoryPool* pool = nullptr,
std::function<void(folly::IOBuf&)> onDestructionCb = nullptr);

~SerializedPage();
Expand All @@ -47,7 +47,7 @@ class SerializedPage {

// Makes 'input' ready for deserializing 'this' with
// VectorStreamGroup::read().
void prepareStreamForDeserialize(ByteStream* FOLLY_NONNULL input);
void prepareStreamForDeserialize(ByteStream* input);

std::unique_ptr<folly::IOBuf> getIOBuf() const {
return iobuf_->clone();
Expand All @@ -70,7 +70,7 @@ class SerializedPage {

// Number of payload bytes in 'iobuf_'.
const int64_t iobufBytes_;
memory::MemoryPool* FOLLY_NULLABLE pool_;
memory::MemoryPool* pool_;

// Callback that will be called on destruction of the SerializedPage,
// primarily used to free externally allocated memory backing folly::IOBuf
Expand All @@ -87,7 +87,6 @@ class ExchangeQueue {
explicit ExchangeQueue(int64_t minBytes) : minBytes_(minBytes) {}

~ExchangeQueue() {
std::lock_guard<std::mutex> l(mutex_);
clearAllPromises();
}

Expand All @@ -99,36 +98,47 @@ class ExchangeQueue {
return queue_.empty();
}

void enqueue(std::unique_ptr<SerializedPage>&& page) {
if (!page) {
void enqueueLocked(
std::unique_ptr<SerializedPage>&& page,
std::vector<ContinuePromise>& promises) {
if (page == nullptr) {
++numCompleted_;
checkComplete();
auto completedPromises = checkCompleteLocked();
promises.reserve(promises.size() + completedPromises.size());
for (auto& promise : completedPromises) {
promises.push_back(std::move(promise));
}
return;
}
totalBytes_ += page->size();
queue_.push_back(std::move(page));
if (!promises_.empty()) {
// Resume one of the waiting drivers.
promises_.back().setValue();
promises.push_back(std::move(promises_.back()));
promises_.pop_back();
}
}

// If data is permanently not available, e.g. the source cannot be
// contacted, this registers an error message and causes the reading
// Exchanges to throw with the message
void setErrorLocked(const std::string& error) {
if (!error_.empty()) {
return;
void setError(const std::string& error) {
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(mutex_);
if (!error_.empty()) {
return;
}
error_ = error;
atEnd_ = true;
promises = clearAllPromisesLocked();
}
error_ = error;
atEnd_ = true;
clearAllPromises();
clearPromises(promises);
}

std::unique_ptr<SerializedPage> dequeue(
bool* FOLLY_NONNULL atEnd,
ContinueFuture* FOLLY_NONNULL future) {
std::unique_ptr<SerializedPage> dequeueLocked(
bool* atEnd,
ContinueFuture* future) {
VELOX_CHECK(future);
if (!error_.empty()) {
*atEnd = true;
Expand Down Expand Up @@ -163,34 +173,61 @@ class ExchangeQueue {
return minBytes_;
}

void addSource() {
void addSourceLocked() {
VELOX_CHECK(!noMoreSources_, "addSource called after noMoreSources");
numSources_++;
}

void noMoreSources() {
noMoreSources_ = true;
checkComplete();
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(mutex_);
noMoreSources_ = true;
promises = checkCompleteLocked();
}
clearPromises(promises);
}

void closeLocked() {
queue_.clear();
clearAllPromises();
void close() {
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(mutex_);
promises = closeLocked();
}
clearPromises(promises);
}

private:
void checkComplete() {
std::vector<ContinuePromise> closeLocked() {
queue_.clear();
return clearAllPromisesLocked();
}

std::vector<ContinuePromise> checkCompleteLocked() {
if (noMoreSources_ && numCompleted_ == numSources_) {
atEnd_ = true;
clearAllPromises();
return clearAllPromisesLocked();
}
return {};
}

void clearAllPromises() {
for (auto& promise : promises_) {
std::vector<ContinuePromise> promises;
{
std::lock_guard<std::mutex> l(mutex_);
promises = clearAllPromisesLocked();
}
clearPromises(promises);
}

std::vector<ContinuePromise> clearAllPromisesLocked() {
return std::move(promises_);
}

static void clearPromises(std::vector<ContinuePromise>& promises) {
for (auto& promise : promises) {
promise.setValue();
}
promises_.clear();
}

int numCompleted_ = 0;
Expand All @@ -217,13 +254,13 @@ class ExchangeSource : public std::enable_shared_from_this<ExchangeSource> {
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool)>;
memory::MemoryPool* pool)>;

ExchangeSource(
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool)
memory::MemoryPool* pool)
: taskId_(taskId),
destination_(destination),
queue_(std::move(queue)),
Expand All @@ -235,7 +272,7 @@ class ExchangeSource : public std::enable_shared_from_this<ExchangeSource> {
const std::string& taskId,
int destination,
std::shared_ptr<ExchangeQueue> queue,
memory::MemoryPool* FOLLY_NONNULL pool);
memory::MemoryPool* pool);

// Returns true if there is no request to the source pending or if
// this should be retried. If true, the caller is expected to call
Expand Down Expand Up @@ -283,7 +320,7 @@ class ExchangeSource : public std::enable_shared_from_this<ExchangeSource> {
bool atEnd_ = false;

protected:
memory::MemoryPool* FOLLY_NONNULL pool_;
memory::MemoryPool* pool_;
};

struct RemoteConnectorSplit : public connector::ConnectorSplit {
Expand All @@ -301,7 +338,7 @@ class ExchangeClient {

ExchangeClient(
int destination,
memory::MemoryPool* FOLLY_NONNULL pool,
memory::MemoryPool* pool,
int64_t minSize = kDefaultMinSize)
: destination_(destination),
pool_(pool),
Expand All @@ -315,7 +352,7 @@ class ExchangeClient {

~ExchangeClient();

memory::MemoryPool* FOLLY_NULLABLE pool() const {
memory::MemoryPool* pool() const {
return pool_;
}

Expand All @@ -334,15 +371,13 @@ class ExchangeClient {
return queue_;
}

std::unique_ptr<SerializedPage> next(
bool* FOLLY_NONNULL atEnd,
ContinueFuture* FOLLY_NONNULL future);
std::unique_ptr<SerializedPage> next(bool* atEnd, ContinueFuture* future);

std::string toString();

private:
const int destination_;
memory::MemoryPool* const FOLLY_NONNULL pool_;
memory::MemoryPool* const pool_;
std::shared_ptr<ExchangeQueue> queue_;
std::unordered_set<std::string> taskIds_;
std::vector<std::shared_ptr<ExchangeSource>> sources_;
Expand All @@ -353,7 +388,7 @@ class Exchange : public SourceOperator {
public:
Exchange(
int32_t operatorId,
DriverCtx* FOLLY_NONNULL ctx,
DriverCtx* ctx,
const std::shared_ptr<const core::ExchangeNode>& exchangeNode,
std::shared_ptr<ExchangeClient> exchangeClient,
const std::string& operatorType = "Exchange")
Expand Down Expand Up @@ -382,7 +417,7 @@ class Exchange : public SourceOperator {
exchangeClient_ = nullptr;
}

BlockingReason isBlocked(ContinueFuture* FOLLY_NONNULL future) override;
BlockingReason isBlocked(ContinueFuture* future) override;

bool isFinished() override;

Expand All @@ -397,7 +432,7 @@ class Exchange : public SourceOperator {
/// this operator is not the first operator in the pipeline and therefore is
/// not responsible for fetching splits and adding them to the
/// exchangeClient_.
bool getSplits(ContinueFuture* FOLLY_NONNULL future);
bool getSplits(ContinueFuture* future);

const core::PlanNodeId planNodeId_;
bool noMoreSplits_ = false;
Expand Down
10 changes: 6 additions & 4 deletions velox/exec/tests/PartitionedOutputBufferManagerTest.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -367,13 +367,15 @@ TEST_F(PartitionedOutputBufferManagerTest, outOfOrderAcks) {
TEST_F(PartitionedOutputBufferManagerTest, errorInQueue) {
auto queue = std::make_shared<ExchangeQueue>(1 << 20);
auto page = std::make_unique<SerializedPage>(folly::IOBuf::copyBuffer("", 0));
{
std::lock_guard<std::mutex> l(queue->mutex());
queue->setErrorLocked("error");
std::vector<ContinuePromise> promises;
{ queue->setError("error"); }
for (auto& promise : promises) {
promise.setValue();
}
ContinueFuture future;
bool atEnd = false;
EXPECT_THROW(auto page = queue->dequeue(&atEnd, &future), std::runtime_error);
EXPECT_THROW(
auto page = queue->dequeueLocked(&atEnd, &future), std::runtime_error);
}

TEST_F(PartitionedOutputBufferManagerTest, serializedPage) {
Expand Down