diff --git a/presto-native-execution/presto_cpp/main/PrestoTask.h b/presto-native-execution/presto_cpp/main/PrestoTask.h index e159bdf2247ba..96b3d84b9b420 100644 --- a/presto-native-execution/presto_cpp/main/PrestoTask.h +++ b/presto-native-execution/presto_cpp/main/PrestoTask.h @@ -14,6 +14,7 @@ #pragma once #include +#include "presto_cpp/main/http/HttpServer.h" #include "presto_cpp/main/types/PrestoTaskId.h" #include "presto_cpp/presto_protocol/presto_protocol.h" #include "velox/exec/Task.h" @@ -59,10 +60,25 @@ struct Result { struct ResultRequest { PromiseHolderWeakPtr> promise; + std::weak_ptr state; protocol::TaskId taskId; int64_t bufferId; int64_t token; protocol::DataSize maxSize; + + ResultRequest( + PromiseHolderWeakPtr> _promise, + std::weak_ptr _state, + protocol::TaskId _taskId, + int64_t _bufferId, + int64_t _token, + protocol::DataSize _maxSize) + : promise(std::move(_promise)), + state(std::move(_state)), + taskId(_taskId), + bufferId(_bufferId), + token(_token), + maxSize(_maxSize) {} }; struct PrestoTask { diff --git a/presto-native-execution/presto_cpp/main/TaskManager.cpp b/presto-native-execution/presto_cpp/main/TaskManager.cpp index 21dbbcfbbb856..7481a77db0055 100644 --- a/presto-native-execution/presto_cpp/main/TaskManager.cpp +++ b/presto-native-execution/presto_cpp/main/TaskManager.cpp @@ -98,6 +98,7 @@ std::unique_ptr createCompleteResult(long token) { void getData( PromiseHolderPtr> promiseHolder, + std::weak_ptr stateHolder, const TaskId& taskId, long destination, long token, @@ -154,6 +155,13 @@ void getData( RECORD_METRIC_VALUE( kCounterPartitionedOutputBufferGetDataLatencyMs, getCurrentTimeMs() - startMs); + }, + [stateHolder]() { + auto state = stateHolder.lock(); + if (state == nullptr) { + return false; + } + return !state->requestExpired(); }); if (!bufferFound) { @@ -382,6 +390,7 @@ void TaskManager::getDataForResultRequests( << ", sequence " << resultRequest->token; getData( resultRequest->promise.lock(), + resultRequest->state, resultRequest->taskId, resultRequest->bufferId, resultRequest->token, @@ -831,6 +840,7 @@ folly::Future> TaskManager::getResults( if (prestoTask->task->state() == exec::kRunning) { getData( promiseHolder, + folly::to_weak_ptr(state), taskId, destination, token, @@ -852,12 +862,13 @@ folly::Future> TaskManager::getResults( keepPromiseAlive(promiseHolder, state); - auto request = std::make_unique(); - request->promise = folly::to_weak_ptr(promiseHolder); - request->taskId = taskId; - request->bufferId = destination; - request->token = token; - request->maxSize = maxSize; + auto request = std::make_unique( + folly::to_weak_ptr(promiseHolder), + folly::to_weak_ptr(state), + taskId, + destination, + token, + maxSize); prestoTask->resultRequests.insert({destination, std::move(request)}); return std::move(future) .via(httpSrvCpuExecutor_) diff --git a/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp b/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp index 0edd4ed845ee4..216402924f49e 100644 --- a/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp @@ -35,6 +35,11 @@ void verifyQueryCtxCache( } // namespace class QueryContextCacheTest : public testing::Test { + protected: + static void SetUpTestCase() { + memory::MemoryManager::testingSetInstance({}); + } + void SetUp() override { FLAGS_velox_memory_leak_check_enabled = true; } diff --git a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp index 02af7db62a2b0..970937f9f3aee 100644 --- a/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/TaskManagerTest.cpp @@ -15,6 +15,7 @@ #include #include #include +#include "folly/experimental/EventCount.h" #include "presto_cpp/main/PrestoExchangeSource.h" #include "presto_cpp/main/TaskResource.h" #include "presto_cpp/main/tests/HttpServerWrapper.h" @@ -27,6 +28,7 @@ #include "velox/dwio/common/WriterFactory.h" #include "velox/dwio/common/tests/utils/BatchMaker.h" #include "velox/exec/Exchange.h" +#include "velox/exec/Values.h" #include "velox/exec/tests/utils/PlanBuilder.h" #include "velox/exec/tests/utils/QueryAssertions.h" #include "velox/exec/tests/utils/TempDirectoryPath.h" @@ -146,6 +148,7 @@ class TaskManagerTest : public testing::Test { public: static void SetUpTestCase() { memory::MemoryManager::testingSetInstance({}); + common::testutil::TestValue::enable(); } protected: @@ -709,6 +712,70 @@ TEST_F(TaskManagerTest, fecthFromFinishedTask) { ASSERT_TRUE(newResult.value()->complete); } +DEBUG_ONLY_TEST_F(TaskManagerTest, fecthFromArbitraryOutput) { + // Block output until the first fetch destination becomes inactive. + folly::EventCount outputWait; + std::atomic outputWaitFlag{false}; + SCOPED_TESTVALUE_SET( + "facebook::velox::exec::Values::getOutput", + std::function( + [&](const velox::exec::Values* values) { + outputWait.await([&]() { return outputWaitFlag.load(); }); + })); + + const std::vector batches = makeVectors(1, 1'000); + auto planFragment = exec::test::PlanBuilder() + .values(batches) + .partitionedOutputArbitrary({"c0", "c1"}) + .planFragment(); + const protocol::TaskId taskId = "source.0.0.1.0"; + const auto taskInfo = createOrUpdateTask(taskId, {}, planFragment); + + const protocol::Duration longWait("10s"); + const auto maxSize = protocol::DataSize("1024MB"); + auto expiredRequestState = http::CallbackRequestHandlerState::create(); + auto consumeCompleted = false; + // Consume from destination 0 to simulate the case that the http request has + // expired while destination has notify setup. + auto expiredResultWait = taskManager_->getResults( + taskId, 0, 0, maxSize, protocol::Duration("1s"), expiredRequestState); + // Reset the http request to simulate the case that it has expired. + expiredRequestState.reset(); + + // Unblock output. + outputWaitFlag = true; + outputWait.notifyAll(); + + // Consuming from destination 1 and expect get result. + auto requestState = http::CallbackRequestHandlerState::create(); + const auto result = + taskManager_ + ->getResults( + taskId, 1, 0, maxSize, protocol::Duration("10s"), requestState) + .getVia(folly::EventBaseManager::get()->getEventBase()); + ASSERT_FALSE(result->complete); + ASSERT_FALSE(result->data->empty()); + ASSERT_EQ(result->sequence, 0); + ASSERT_EQ(result->nextSequence, 1); + + // Check the expired result hasn't fetched any data after timeout. + const auto expriredResult = + std::move(expiredResultWait) + .getVia(folly::EventBaseManager::get()->getEventBase()); + ASSERT_FALSE(expriredResult->complete); + ASSERT_TRUE(expriredResult->data->empty()); + ASSERT_EQ(expriredResult->sequence, 0); + ASSERT_EQ(expriredResult->nextSequence, 0); + + // Close destinations and triggers the task closure. + taskManager_->abortResults(taskId, 0); + taskManager_->abortResults(taskId, 1); + + auto prestoTask = taskManager_->tasks().at(taskId); + ASSERT_TRUE(waitForTaskStateChange( + prestoTask->task.get(), TaskState::kFinished, 3'000'000)); +} + TEST_F(TaskManagerTest, taskCleanupWithPendingResultData) { // Trigger old task cleanup immediately. taskManager_->setOldTaskCleanUpMs(0); @@ -1103,12 +1170,14 @@ TEST_F(TaskManagerTest, getDataOnAbortedTask) { }); auto promiseHolder = std::make_shared>>( std::move(promise)); - auto request = std::make_unique(); - request->promise = folly::to_weak_ptr(promiseHolder); - request->taskId = scanTaskId; - request->bufferId = 0; - request->token = token; - request->maxSize = protocol::DataSize("32MB"); + auto requestState = http::CallbackRequestHandlerState::create(); + auto request = std::make_unique( + folly::to_weak_ptr(promiseHolder), + folly::to_weak_ptr(requestState), + scanTaskId, + 0, + token, + protocol::DataSize("32MB")); prestoTask->resultRequests.insert({0, std::move(request)}); prestoTask->task = createDummyExecTask(scanTaskId, planFragment); taskManager_->getDataForResultRequests(prestoTask->resultRequests); diff --git a/presto-native-execution/velox b/presto-native-execution/velox index 0f96c1c26562f..26b8ca559de5f 160000 --- a/presto-native-execution/velox +++ b/presto-native-execution/velox @@ -1 +1 @@ -Subproject commit 0f96c1c26562f8702e58edd2f54c6503225beff8 +Subproject commit 26b8ca559de5f459cc65c37cdff2c52507aff8c9