Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 16 additions & 0 deletions presto-native-execution/presto_cpp/main/PrestoTask.h
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#pragma once

#include <memory>
#include "presto_cpp/main/http/HttpServer.h"
#include "presto_cpp/main/types/PrestoTaskId.h"
#include "presto_cpp/presto_protocol/presto_protocol.h"
#include "velox/exec/Task.h"
Expand Down Expand Up @@ -59,10 +60,25 @@ struct Result {

struct ResultRequest {
PromiseHolderWeakPtr<std::unique_ptr<Result>> promise;
std::weak_ptr<http::CallbackRequestHandlerState> state;
protocol::TaskId taskId;
int64_t bufferId;
int64_t token;
protocol::DataSize maxSize;

ResultRequest(
PromiseHolderWeakPtr<std::unique_ptr<Result>> _promise,
std::weak_ptr<http::CallbackRequestHandlerState> _state,
protocol::TaskId _taskId,
int64_t _bufferId,
int64_t _token,
protocol::DataSize _maxSize)
: promise(std::move(_promise)),
state(std::move(_state)),
taskId(_taskId),
bufferId(_bufferId),
token(_token),
maxSize(_maxSize) {}
};

struct PrestoTask {
Expand Down
23 changes: 17 additions & 6 deletions presto-native-execution/presto_cpp/main/TaskManager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ std::unique_ptr<Result> createCompleteResult(long token) {

void getData(
PromiseHolderPtr<std::unique_ptr<Result>> promiseHolder,
std::weak_ptr<http::CallbackRequestHandlerState> stateHolder,
const TaskId& taskId,
long destination,
long token,
Expand Down Expand Up @@ -154,6 +155,13 @@ void getData(
RECORD_METRIC_VALUE(
kCounterPartitionedOutputBufferGetDataLatencyMs,
getCurrentTimeMs() - startMs);
},
[stateHolder]() {
auto state = stateHolder.lock();
if (state == nullptr) {
return false;
}
return !state->requestExpired();
});

if (!bufferFound) {
Expand Down Expand Up @@ -382,6 +390,7 @@ void TaskManager::getDataForResultRequests(
<< ", sequence " << resultRequest->token;
getData(
resultRequest->promise.lock(),
resultRequest->state,
resultRequest->taskId,
resultRequest->bufferId,
resultRequest->token,
Expand Down Expand Up @@ -831,6 +840,7 @@ folly::Future<std::unique_ptr<Result>> TaskManager::getResults(
if (prestoTask->task->state() == exec::kRunning) {
getData(
promiseHolder,
folly::to_weak_ptr(state),
taskId,
destination,
token,
Expand All @@ -852,12 +862,13 @@ folly::Future<std::unique_ptr<Result>> TaskManager::getResults(

keepPromiseAlive(promiseHolder, state);

auto request = std::make_unique<ResultRequest>();
request->promise = folly::to_weak_ptr(promiseHolder);
request->taskId = taskId;
request->bufferId = destination;
request->token = token;
request->maxSize = maxSize;
auto request = std::make_unique<ResultRequest>(
folly::to_weak_ptr(promiseHolder),
folly::to_weak_ptr(state),
taskId,
destination,
token,
maxSize);
prestoTask->resultRequests.insert({destination, std::move(request)});
return std::move(future)
.via(httpSrvCpuExecutor_)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,11 @@ void verifyQueryCtxCache(
} // namespace

class QueryContextCacheTest : public testing::Test {
protected:
static void SetUpTestCase() {
memory::MemoryManager::testingSetInstance({});
}

void SetUp() override {
FLAGS_velox_memory_leak_check_enabled = true;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include <folly/executors/ThreadedExecutor.h>
#include <gmock/gmock.h>
#include <gtest/gtest.h>
#include "folly/experimental/EventCount.h"
#include "presto_cpp/main/PrestoExchangeSource.h"
#include "presto_cpp/main/TaskResource.h"
#include "presto_cpp/main/tests/HttpServerWrapper.h"
Expand All @@ -27,6 +28,7 @@
#include "velox/dwio/common/WriterFactory.h"
#include "velox/dwio/common/tests/utils/BatchMaker.h"
#include "velox/exec/Exchange.h"
#include "velox/exec/Values.h"
#include "velox/exec/tests/utils/PlanBuilder.h"
#include "velox/exec/tests/utils/QueryAssertions.h"
#include "velox/exec/tests/utils/TempDirectoryPath.h"
Expand Down Expand Up @@ -146,6 +148,7 @@ class TaskManagerTest : public testing::Test {
public:
static void SetUpTestCase() {
memory::MemoryManager::testingSetInstance({});
common::testutil::TestValue::enable();
}

protected:
Expand Down Expand Up @@ -709,6 +712,70 @@ TEST_F(TaskManagerTest, fecthFromFinishedTask) {
ASSERT_TRUE(newResult.value()->complete);
}

DEBUG_ONLY_TEST_F(TaskManagerTest, fecthFromArbitraryOutput) {
// Block output until the first fetch destination becomes inactive.
folly::EventCount outputWait;
std::atomic<bool> outputWaitFlag{false};
SCOPED_TESTVALUE_SET(
"facebook::velox::exec::Values::getOutput",
std::function<void(const velox::exec::Values*)>(
[&](const velox::exec::Values* values) {
outputWait.await([&]() { return outputWaitFlag.load(); });
}));

const std::vector<RowVectorPtr> batches = makeVectors(1, 1'000);
auto planFragment = exec::test::PlanBuilder()
.values(batches)
.partitionedOutputArbitrary({"c0", "c1"})
.planFragment();
const protocol::TaskId taskId = "source.0.0.1.0";
const auto taskInfo = createOrUpdateTask(taskId, {}, planFragment);

const protocol::Duration longWait("10s");
const auto maxSize = protocol::DataSize("1024MB");
auto expiredRequestState = http::CallbackRequestHandlerState::create();
auto consumeCompleted = false;
// Consume from destination 0 to simulate the case that the http request has
// expired while destination has notify setup.
auto expiredResultWait = taskManager_->getResults(
taskId, 0, 0, maxSize, protocol::Duration("1s"), expiredRequestState);
// Reset the http request to simulate the case that it has expired.
expiredRequestState.reset();

// Unblock output.
outputWaitFlag = true;
outputWait.notifyAll();

// Consuming from destination 1 and expect get result.
auto requestState = http::CallbackRequestHandlerState::create();
const auto result =
taskManager_
->getResults(
taskId, 1, 0, maxSize, protocol::Duration("10s"), requestState)
.getVia(folly::EventBaseManager::get()->getEventBase());
ASSERT_FALSE(result->complete);
ASSERT_FALSE(result->data->empty());
ASSERT_EQ(result->sequence, 0);
ASSERT_EQ(result->nextSequence, 1);

// Check the expired result hasn't fetched any data after timeout.
const auto expriredResult =
std::move(expiredResultWait)
.getVia(folly::EventBaseManager::get()->getEventBase());
ASSERT_FALSE(expriredResult->complete);
ASSERT_TRUE(expriredResult->data->empty());
ASSERT_EQ(expriredResult->sequence, 0);
ASSERT_EQ(expriredResult->nextSequence, 0);

// Close destinations and triggers the task closure.
taskManager_->abortResults(taskId, 0);
taskManager_->abortResults(taskId, 1);

auto prestoTask = taskManager_->tasks().at(taskId);
ASSERT_TRUE(waitForTaskStateChange(
prestoTask->task.get(), TaskState::kFinished, 3'000'000));
}

TEST_F(TaskManagerTest, taskCleanupWithPendingResultData) {
// Trigger old task cleanup immediately.
taskManager_->setOldTaskCleanUpMs(0);
Expand Down Expand Up @@ -1103,12 +1170,14 @@ TEST_F(TaskManagerTest, getDataOnAbortedTask) {
});
auto promiseHolder = std::make_shared<PromiseHolder<std::unique_ptr<Result>>>(
std::move(promise));
auto request = std::make_unique<ResultRequest>();
request->promise = folly::to_weak_ptr(promiseHolder);
request->taskId = scanTaskId;
request->bufferId = 0;
request->token = token;
request->maxSize = protocol::DataSize("32MB");
auto requestState = http::CallbackRequestHandlerState::create();
auto request = std::make_unique<ResultRequest>(
folly::to_weak_ptr(promiseHolder),
folly::to_weak_ptr(requestState),
scanTaskId,
0,
token,
protocol::DataSize("32MB"));
prestoTask->resultRequests.insert({0, std::move(request)});
prestoTask->task = createDummyExecTask(scanTaskId, planFragment);
taskManager_->getDataForResultRequests(prestoTask->resultRequests);
Expand Down