From efe03b19d5966a6dd73fed0b2068c7c03bf8a23a Mon Sep 17 00:00:00 2001 From: Jialiang Tan Date: Fri, 14 Nov 2025 13:12:33 -0800 Subject: [PATCH] feat: Add findOrCreateBatchQueryCtx to query ctx manager (#26619) Summary: In Batch mode, only one query is running at a time. When tasks fail during memory arbitration, the query memory pool will be set aborted, failing any successive tasks immediately. Yet one task should not fail other newly admitted tasks because of task retries and server reuse. Failure control among tasks should be independent. So if query memory pool is aborted already, a cache clear is performed to allow successive tasks to create a new query context to continue execution. This change also changes the folly::Synchronized lock to std::mutex for more flexible locking. Reviewed By: xiaoxmeng Differential Revision: D87006158 --- .../presto_cpp/main/QueryContextManager.cpp | 126 +++++++++++++++--- .../presto_cpp/main/QueryContextManager.h | 83 +++--------- .../presto_cpp/main/TaskResource.cpp | 2 +- .../main/tests/QueryContextCacheTest.cpp | 2 +- .../main/tests/QueryContextManagerTest.cpp | 60 ++++++++- 5 files changed, 191 insertions(+), 82 deletions(-) diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp index f31868cfca41d..3598603c300eb 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.cpp +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.cpp @@ -34,6 +34,68 @@ inline QueryId queryIdFromTaskId(const TaskId& taskId) { } // namespace +std::shared_ptr QueryContextCache::get( + const protocol::QueryId& queryId) { + auto iter = queryCtxs_.find(queryId); + if (iter == queryCtxs_.end()) { + return nullptr; + } + + queryIds_.erase(iter->second.idListIterator); + + if (auto queryCtx = iter->second.queryCtx.lock()) { + // Move the queryId to front, if queryCtx is still alive. + queryIds_.push_front(queryId); + iter->second.idListIterator = queryIds_.begin(); + return queryCtx; + } + queryCtxs_.erase(iter); + return nullptr; +} + +std::shared_ptr QueryContextCache::insert( + const protocol::QueryId& queryId, + std::shared_ptr queryCtx) { + if (queryCtxs_.size() >= capacity_) { + evict(); + } + queryIds_.push_front(queryId); + queryCtxs_[queryId] = { + folly::to_weak_ptr(queryCtx), queryIds_.begin(), false}; + return queryCtx; +} + +bool QueryContextCache::hasStartedTasks( + const protocol::QueryId& queryId) const { + auto iter = queryCtxs_.find(queryId); + if (iter != queryCtxs_.end()) { + return iter->second.hasStartedTasks; + } + return false; +} + +void QueryContextCache::setTasksStarted(const protocol::QueryId& queryId) { + auto iter = queryCtxs_.find(queryId); + if (iter != queryCtxs_.end()) { + iter->second.hasStartedTasks = true; + } +} + +void QueryContextCache::evict() { + // Evict least recently used queryCtx if it is not referenced elsewhere. + for (auto victim = queryIds_.end(); victim != queryIds_.begin();) { + --victim; + if (!queryCtxs_[*victim].queryCtx.lock()) { + queryCtxs_.erase(*victim); + queryIds_.erase(victim); + return; + } + } + + // All queries are still inflight. Increase capacity. + capacity_ = std::max(kInitialCapacity, capacity_ * 2); +} + QueryContextManager::QueryContextManager( folly::Executor* driverExecutor, folly::Executor* spillerExecutor) @@ -43,25 +105,58 @@ std::shared_ptr QueryContextManager::findOrCreateQueryCtx( const protocol::TaskId& taskId, const protocol::TaskUpdateRequest& taskUpdateRequest) { - return findOrCreateQueryCtx( + std::lock_guard lock(queryContextCacheMutex_); + return findOrCreateQueryCtxLocked( taskId, toVeloxConfigs( taskUpdateRequest.session, taskUpdateRequest.extraCredentials), toConnectorConfigs(taskUpdateRequest)); } +std::shared_ptr +QueryContextManager::findOrCreateBatchQueryCtx( + const protocol::TaskId& taskId, + const protocol::TaskUpdateRequest& taskUpdateRequest) { + std::lock_guard lock(queryContextCacheMutex_); + auto queryCtx = findOrCreateQueryCtxLocked( + taskId, + toVeloxConfigs( + taskUpdateRequest.session, taskUpdateRequest.extraCredentials), + toConnectorConfigs(taskUpdateRequest)); + if (queryCtx->pool()->aborted()) { + // In Batch mode, only one query is running at a time. When tasks fail + // during memory arbitration, the query memory pool will be set + // aborted, failing any successive tasks immediately. Yet one task + // should not fail other newly admitted tasks because of task retries + // and server reuse. Failure control among tasks should be + // independent. So if query memory pool is aborted already, a cache clear is + // performed to allow successive tasks to create a new query context to + // continue execution. + VELOX_CHECK_EQ(queryContextCache_.size(), 1); + queryContextCache_.clear(); + queryCtx = findOrCreateQueryCtxLocked( + taskId, + toVeloxConfigs( + taskUpdateRequest.session, taskUpdateRequest.extraCredentials), + toConnectorConfigs(taskUpdateRequest)); + } + return queryCtx; +} + bool QueryContextManager::queryHasStartedTasks( const protocol::TaskId& taskId) const { - return queryContextCache_.rlock()->hasStartedTasks(queryIdFromTaskId(taskId)); + std::lock_guard lock(queryContextCacheMutex_); + return queryContextCache_.hasStartedTasks(queryIdFromTaskId(taskId)); } void QueryContextManager::setQueryHasStartedTasks( const protocol::TaskId& taskId) { - queryContextCache_.wlock()->setHasStartedTasks(queryIdFromTaskId(taskId)); + std::lock_guard lock(queryContextCacheMutex_); + queryContextCache_.setTasksStarted(queryIdFromTaskId(taskId)); } -std::shared_ptr QueryContextManager::createAndCacheQueryCtx( - QueryContextCache& cache, +std::shared_ptr +QueryContextManager::createAndCacheQueryCtxLocked( const QueryId& queryId, velox::core::QueryConfig&& queryConfig, std::unordered_map>&& @@ -75,18 +170,17 @@ std::shared_ptr QueryContextManager::createAndCacheQueryCtx( std::move(pool), spillerExecutor_, queryId); - return cache.insert(queryId, std::move(queryCtx)); + return queryContextCache_.insert(queryId, std::move(queryCtx)); } -std::shared_ptr QueryContextManager::findOrCreateQueryCtx( +std::shared_ptr QueryContextManager::findOrCreateQueryCtxLocked( const TaskId& taskId, velox::core::QueryConfig&& queryConfig, std::unordered_map>&& connectorConfigs) { const QueryId queryId{queryIdFromTaskId(taskId)}; - auto lockedCache = queryContextCache_.wlock(); - if (auto queryCtx = lockedCache->get(queryId)) { + if (auto queryCtx = queryContextCache_.get(queryId)) { return queryCtx; } @@ -111,8 +205,7 @@ std::shared_ptr QueryContextManager::findOrCreateQueryCtx( nullptr, poolDbgOpts); - return createAndCacheQueryCtx( - *lockedCache, + return createAndCacheQueryCtxLocked( queryId, std::move(queryConfig), std::move(connectorConfigs), @@ -123,19 +216,20 @@ void QueryContextManager::visitAllContexts( const std::function< void(const protocol::QueryId&, const velox::core::QueryCtx*)>& visitor) const { - auto lockedCache = queryContextCache_.rlock(); - for (const auto& it : lockedCache->ctxs()) { + std::lock_guard lock(queryContextCacheMutex_); + for (const auto& it : queryContextCache_.ctxMap()) { if (const auto queryCtxSP = it.second.queryCtx.lock()) { visitor(it.first, queryCtxSP.get()); } } } -void QueryContextManager::testingClearCache() { - queryContextCache_.wlock()->testingClear(); +void QueryContextManager::clearCache() { + std::lock_guard lock(queryContextCacheMutex_); + queryContextCache_.clear(); } -void QueryContextCache::testingClear() { +void QueryContextCache::clear() { queryCtxs_.clear(); queryIds_.clear(); } diff --git a/presto-native-execution/presto_cpp/main/QueryContextManager.h b/presto-native-execution/presto_cpp/main/QueryContextManager.h index 61ce3b78c203b..3a07b73c36335 100644 --- a/presto-native-execution/presto_cpp/main/QueryContextManager.h +++ b/presto-native-execution/presto_cpp/main/QueryContextManager.h @@ -17,6 +17,7 @@ #include #include #include +#include #include #include "presto_cpp/presto_protocol/core/presto_protocol_core.h" @@ -44,77 +45,31 @@ class QueryContextCache { return queryCtxs_.size(); } - std::shared_ptr get(const protocol::QueryId& queryId) { - auto iter = queryCtxs_.find(queryId); - if (iter != queryCtxs_.end()) { - queryIds_.erase(iter->second.idListIterator); - - if (auto queryCtx = iter->second.queryCtx.lock()) { - // Move the queryId to front, if queryCtx is still alive. - queryIds_.push_front(queryId); - iter->second.idListIterator = queryIds_.begin(); - return queryCtx; - } else { - queryCtxs_.erase(iter); - } - } - return nullptr; + const QueryCtxMap& ctxMap() const { + return queryCtxs_; } + std::shared_ptr get(const protocol::QueryId& queryId); + std::shared_ptr insert( const protocol::QueryId& queryId, - std::shared_ptr queryCtx) { - if (queryCtxs_.size() >= capacity_) { - evict(); - } - queryIds_.push_front(queryId); - queryCtxs_[queryId] = { - folly::to_weak_ptr(queryCtx), queryIds_.begin(), false}; - return queryCtx; - } + std::shared_ptr queryCtx); - bool hasStartedTasks(const protocol::QueryId& queryId) const { - auto iter = queryCtxs_.find(queryId); - if (iter != queryCtxs_.end()) { - return iter->second.hasStartedTasks; - } - return false; - } + bool hasStartedTasks(const protocol::QueryId& queryId) const; - void setHasStartedTasks(const protocol::QueryId& queryId) { - auto iter = queryCtxs_.find(queryId); - if (iter != queryCtxs_.end()) { - iter->second.hasStartedTasks = true; - } - } + void setTasksStarted(const protocol::QueryId& queryId); - void evict() { - // Evict least recently used queryCtx if it is not referenced elsewhere. - for (auto victim = queryIds_.end(); victim != queryIds_.begin();) { - --victim; - if (!queryCtxs_[*victim].queryCtx.lock()) { - queryCtxs_.erase(*victim); - queryIds_.erase(victim); - return; - } - } - - // All queries are still inflight. Increase capacity. - capacity_ = std::max(kInitialCapacity, capacity_ * 2); - } - const QueryCtxMap& ctxs() const { - return queryCtxs_; - } + void evict(); - void testingClear(); + void clear(); private: + static constexpr size_t kInitialCapacity = 256UL; + size_t capacity_; QueryCtxMap queryCtxs_; QueryIdList queryIds_; - - static constexpr size_t kInitialCapacity = 256UL; }; class QueryContextManager { @@ -129,6 +84,10 @@ class QueryContextManager { const protocol::TaskId& taskId, const protocol::TaskUpdateRequest& taskUpdateRequest); + std::shared_ptr findOrCreateBatchQueryCtx( + const protocol::TaskId& taskId, + const protocol::TaskUpdateRequest& taskUpdateRequest); + /// Returns true if the given task's query has at least one task started. bool queryHasStartedTasks(const protocol::TaskId& taskId) const; @@ -142,15 +101,15 @@ class QueryContextManager { visitor) const; /// Test method to clear the query context cache. - void testingClearCache(); + void clearCache(); protected: folly::Executor* const driverExecutor_{nullptr}; folly::Executor* const spillerExecutor_{nullptr}; + QueryContextCache queryContextCache_; private: - virtual std::shared_ptr createAndCacheQueryCtx( - QueryContextCache& cache, + virtual std::shared_ptr createAndCacheQueryCtxLocked( const protocol::QueryId& queryId, velox::core::QueryConfig&& queryConfig, std::unordered_map< @@ -158,14 +117,14 @@ class QueryContextManager { std::shared_ptr>&& connectorConfigs, std::shared_ptr&& pool); - std::shared_ptr findOrCreateQueryCtx( + std::shared_ptr findOrCreateQueryCtxLocked( const protocol::TaskId& taskId, velox::core::QueryConfig&& queryConfig, std::unordered_map< std::string, std::shared_ptr>&& connectorConfigStrings); - folly::Synchronized queryContextCache_; + mutable std::mutex queryContextCacheMutex_; }; } // namespace facebook::presto diff --git a/presto-native-execution/presto_cpp/main/TaskResource.cpp b/presto-native-execution/presto_cpp/main/TaskResource.cpp index 67ba67ebf325d..f0ff2526d79aa 100644 --- a/presto-native-execution/presto_cpp/main/TaskResource.cpp +++ b/presto-native-execution/presto_cpp/main/TaskResource.cpp @@ -338,7 +338,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateBatchTask( } auto queryCtx = - taskManager_.getQueryContextManager()->findOrCreateQueryCtx( + taskManager_.getQueryContextManager()->findOrCreateBatchQueryCtx( taskId, updateRequest); VeloxBatchQueryPlanConverter converter( diff --git a/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp b/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp index b62c225dcfb1f..2f4a2b73b7467 100644 --- a/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/QueryContextCacheTest.cpp @@ -95,7 +95,7 @@ TEST_F(QueryContextCacheTest, hasStartedTasks) { auto queryId = fmt::format("query-{}", i); EXPECT_FALSE(queryContextCache.hasStartedTasks(queryId)); if (i % 2 == 0) { - queryContextCache.setHasStartedTasks(queryId); + queryContextCache.setTasksStarted(queryId); } } diff --git a/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp b/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp index e9260baa0cc59..c98481ab41651 100644 --- a/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp +++ b/presto-native-execution/presto_cpp/main/tests/QueryContextManagerTest.cpp @@ -248,7 +248,7 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { {false, false, true}}; for (const auto& testData : testSettings) { SCOPED_TRACE(testData.debugString()); - queryCtxManager->testingClearCache(); + queryCtxManager->clearCache(); auto queryCtx = queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeUpdateRequest); @@ -258,7 +258,7 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { queryCtx.reset(); } if (testData.clearCache) { - queryCtxManager->testingClearCache(); + queryCtxManager->clearCache(); } auto newQueryCtx = queryCtxManager->findOrCreateQueryCtx(fakeTaskId, fakeUpdateRequest); @@ -271,4 +271,60 @@ TEST_F(QueryContextManagerTest, duplicateQueryRootPoolName) { } } } + +TEST_F(QueryContextManagerTest, findOrCreateBatchQueryCtx) { + protocol::TaskId taskId = "batch.0.0.1.0"; + protocol::SessionRepresentation session{.systemProperties = {}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; + auto* queryCtxManager = taskManager_->getQueryContextManager(); + + queryCtxManager->clearCache(); + auto queryCtx = + queryCtxManager->findOrCreateBatchQueryCtx(taskId, updateRequest); + ASSERT_NE(queryCtx, nullptr); + ASSERT_FALSE(queryCtx->pool()->aborted()); + + const auto firstPoolName = queryCtx->pool()->name(); + ASSERT_THAT(firstPoolName, testing::HasSubstr("batch_")); + + auto sameQueryCtx = + queryCtxManager->findOrCreateBatchQueryCtx(taskId, updateRequest); + ASSERT_EQ(queryCtx, sameQueryCtx); + ASSERT_EQ(firstPoolName, sameQueryCtx->pool()->name()); +} + +TEST_F(QueryContextManagerTest, findOrCreateBatchQueryCtxWithAbortedPool) { + protocol::TaskId taskId = "batch.0.0.1.0"; + protocol::SessionRepresentation session{.systemProperties = {}}; + protocol::TaskUpdateRequest updateRequest; + updateRequest.session = session; + auto* queryCtxManager = taskManager_->getQueryContextManager(); + + queryCtxManager->clearCache(); + auto queryCtx = + queryCtxManager->findOrCreateBatchQueryCtx(taskId, updateRequest); + ASSERT_NE(queryCtx, nullptr); + ASSERT_FALSE(queryCtx->pool()->aborted()); + + const auto firstPoolName = queryCtx->pool()->name(); + ASSERT_THAT(firstPoolName, testing::HasSubstr("batch_")); + + try { + VELOX_FAIL("Test abortion"); + } catch (...) { + queryCtx->pool()->abort(std::current_exception()); + } + ASSERT_TRUE(queryCtx->pool()->aborted()); + + auto newQueryCtx = + queryCtxManager->findOrCreateBatchQueryCtx(taskId, updateRequest); + ASSERT_NE(newQueryCtx, nullptr); + ASSERT_NE(queryCtx, newQueryCtx); + ASSERT_FALSE(newQueryCtx->pool()->aborted()); + + const auto newPoolName = newQueryCtx->pool()->name(); + ASSERT_THAT(newPoolName, testing::HasSubstr("batch_")); + ASSERT_NE(firstPoolName, newPoolName); +} } // namespace facebook::presto