diff --git a/presto-native-execution/presto_cpp/main/TaskResource.cpp b/presto-native-execution/presto_cpp/main/TaskResource.cpp index 736fc489ea062..b3c41bd0d3463 100644 --- a/presto-native-execution/presto_cpp/main/TaskResource.cpp +++ b/presto-native-execution/presto_cpp/main/TaskResource.cpp @@ -55,6 +55,59 @@ std::optional getMaxWait(proxygen::HTTPMessage* message) { return protocol::Duration( headers.getSingleOrEmpty(protocol::PRESTO_MAX_WAIT_HTTP_HEADER)); } + +bool shouldUseThrift(const proxygen::HTTPMessage& message) { + const auto& acceptHeader = + message.getHeaders().getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); + return acceptHeader.find(http::kMimeTypeApplicationThrift) != + std::string::npos; +} + +template +void sendPrestoResponse( + proxygen::ResponseHandler* downstream, + const T& data, + bool sendThrift) { + if (sendThrift) { + ThriftT thriftData; + toThrift(data, thriftData); + http::sendOkThriftResponse(downstream, thriftWrite(thriftData)); + } else { + http::sendOkResponse(downstream, json(data)); + } +} + +/// Creates a CallbackRequestHandler that executes a void work function on the +/// given executor, then sends an empty OK response. On exception, sends an +/// error response. Used for simple fire-and-forget handlers. +template +proxygen::RequestHandler* executeAndRespond( + folly::Executor* executor, + WorkFn&& workFn) { + return new http::CallbackRequestHandler( + [executor, work = std::forward(workFn)]( + proxygen::HTTPMessage* /*message*/, + const std::vector>& /*body*/, + proxygen::ResponseHandler* downstream, + std::shared_ptr handlerState) { + folly::via(executor, std::move(work)) + .via( + folly::getKeepAliveToken( + folly::EventBaseManager::get()->getEventBase())) + .thenValue([downstream, handlerState](auto&& /* unused */) { + if (!handlerState->requestExpired()) { + http::sendOkResponse(downstream); + } + }) + .thenError( + folly::tag_t{}, + [downstream, handlerState](auto&& e) { + if (!handlerState->requestExpired()) { + http::sendErrorResponse(downstream, e.what()); + } + }); + }); +} } // namespace void TaskResource::registerUris(http::HttpServer& server) { @@ -136,34 +189,9 @@ proxygen::RequestHandler* TaskResource::abortResults( const std::vector& pathMatch) { protocol::TaskId taskId = pathMatch[1]; long destination = folly::to(pathMatch[2]); - return new http::CallbackRequestHandler( - [this, taskId, destination]( - proxygen::HTTPMessage* /*message*/, - const std::vector>& /*body*/, - proxygen::ResponseHandler* downstream, - std::shared_ptr handlerState) { - folly::via( - httpSrvCpuExecutor_, - [this, taskId, destination, handlerState]() { - taskManager_.abortResults(taskId, destination); - return true; - }) - .via( - folly::getKeepAliveToken( - folly::EventBaseManager::get()->getEventBase())) - .thenValue([downstream, handlerState](auto&& /* unused */) { - if (!handlerState->requestExpired()) { - http::sendOkResponse(downstream); - } - }) - .thenError( - folly::tag_t{}, - [downstream, handlerState](auto&& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }); - }); + return executeAndRespond(httpSrvCpuExecutor_, [this, taskId, destination]() { + taskManager_.abortResults(taskId, destination); + }); } proxygen::RequestHandler* TaskResource::acknowledgeResults( @@ -172,34 +200,9 @@ proxygen::RequestHandler* TaskResource::acknowledgeResults( protocol::TaskId taskId = pathMatch[1]; long bufferId = folly::to(pathMatch[2]); long token = folly::to(pathMatch[3]); - - return new http::CallbackRequestHandler( - [this, taskId, bufferId, token]( - proxygen::HTTPMessage* /*message*/, - const std::vector>& /*body*/, - proxygen::ResponseHandler* downstream, - std::shared_ptr handlerState) { - folly::via( - httpSrvCpuExecutor_, - [this, taskId, bufferId, token]() { - taskManager_.acknowledgeResults(taskId, bufferId, token); - return true; - }) - .via( - folly::getKeepAliveToken( - folly::EventBaseManager::get()->getEventBase())) - .thenValue([downstream, handlerState](auto&& /* unused */) { - if (!handlerState->requestExpired()) { - http::sendOkResponse(downstream); - } - }) - .thenError( - folly::tag_t{}, - [downstream, handlerState](auto&& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }); + return executeAndRespond( + httpSrvCpuExecutor_, [this, taskId, bufferId, token]() { + taskManager_.acknowledgeResults(taskId, bufferId, token); }); } @@ -216,10 +219,7 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl( bool summarize = message->hasQueryParam("summarize"); const auto& headers = message->getHeaders(); - const auto& acceptHeader = - headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); - const auto sendThrift = - acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; + const auto sendThrift = shouldUseThrift(*message); const auto& contentHeader = headers.getSingleOrEmpty(proxygen::HTTP_HEADER_CONTENT_TYPE); const auto receiveThrift = @@ -282,14 +282,8 @@ proxygen::RequestHandler* TaskResource::createOrUpdateTaskImpl( folly::EventBaseManager::get()->getEventBase())) .thenValue([downstream, handlerState, sendThrift](auto taskInfo) { if (!handlerState->requestExpired()) { - if (sendThrift) { - thrift::TaskInfo thriftTaskInfo; - toThrift(*taskInfo, thriftTaskInfo); - http::sendOkThriftResponse( - downstream, thriftWrite(thriftTaskInfo)); - } else { - http::sendOkResponse(downstream, json(*taskInfo)); - } + sendPrestoResponse( + downstream, *taskInfo, sendThrift); } }) .thenError( @@ -419,11 +413,7 @@ proxygen::RequestHandler* TaskResource::deleteTask( message->getQueryParam(protocol::PRESTO_ABORT_TASK_URL_PARAM) == "true"; } bool summarize = message->hasQueryParam("summarize"); - const auto& headers = message->getHeaders(); - const auto& acceptHeader = - headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); - const auto sendThrift = - acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; + const auto sendThrift = shouldUseThrift(*message); return new http::CallbackRequestHandler( [this, taskId, abort, summarize, sendThrift]( @@ -448,14 +438,8 @@ proxygen::RequestHandler* TaskResource::deleteTask( sendTaskNotFound(downstream, taskId); return; } - if (sendThrift) { - thrift::TaskInfo thriftTaskInfo; - toThrift(*taskInfo, thriftTaskInfo); - http::sendOkThriftResponse( - downstream, thriftWrite(thriftTaskInfo)); - } else { - http::sendOkResponse(downstream, json(*taskInfo)); - } + sendPrestoResponse( + downstream, *taskInfo, sendThrift); } }) .thenError( @@ -565,12 +549,7 @@ proxygen::RequestHandler* TaskResource::getTaskStatus( protocol::TaskId taskId = pathMatch[1]; auto currentState = getCurrentState(message); auto maxWait = getMaxWait(message); - - const auto& headers = message->getHeaders(); - const auto& acceptHeader = - headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); - const auto sendThrift = - acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; + const auto sendThrift = shouldUseThrift(*message); return new http::CallbackRequestHandler( [this, sendThrift, taskId, currentState, maxWait]( @@ -596,15 +575,10 @@ proxygen::RequestHandler* TaskResource::getTaskStatus( [sendThrift, downstream, taskId, handlerState]( std::unique_ptr taskStatus) { if (!handlerState->requestExpired()) { - if (sendThrift) { - thrift::TaskStatus thriftTaskStatus; - toThrift(*taskStatus, thriftTaskStatus); - http::sendOkThriftResponse( - downstream, thriftWrite(thriftTaskStatus)); - } else { - json taskStatusJson = *taskStatus; - http::sendOkResponse(downstream, taskStatusJson); - } + sendPrestoResponse< + protocol::TaskStatus, + thrift::TaskStatus>( + downstream, *taskStatus, sendThrift); } }) .thenError( @@ -629,12 +603,7 @@ proxygen::RequestHandler* TaskResource::getTaskInfo( auto currentState = getCurrentState(message); auto maxWait = getMaxWait(message); bool summarize = message->hasQueryParam("summarize"); - - const auto& headers = message->getHeaders(); - const auto& acceptHeader = - headers.getSingleOrEmpty(proxygen::HTTP_HEADER_ACCEPT); - const auto sendThrift = - acceptHeader.find(http::kMimeTypeApplicationThrift) != std::string::npos; + const auto sendThrift = shouldUseThrift(*message); return new http::CallbackRequestHandler( [this, taskId, currentState, maxWait, summarize, sendThrift]( @@ -661,14 +630,8 @@ proxygen::RequestHandler* TaskResource::getTaskInfo( .thenValue([downstream, taskId, handlerState, sendThrift]( std::unique_ptr taskInfo) { if (!handlerState->requestExpired()) { - if (sendThrift) { - thrift::TaskInfo thriftTaskInfo; - toThrift(*taskInfo, thriftTaskInfo); - http::sendOkThriftResponse( - downstream, thriftWrite(thriftTaskInfo)); - } else { - http::sendOkResponse(downstream, json(*taskInfo)); - } + sendPrestoResponse( + downstream, *taskInfo, sendThrift); } }) .thenError( @@ -690,33 +653,8 @@ proxygen::RequestHandler* TaskResource::removeRemoteSource( const std::vector& pathMatch) { protocol::TaskId taskId = pathMatch[1]; auto remoteId = pathMatch[2]; - - return new http::CallbackRequestHandler( - [this, taskId, remoteId]( - proxygen::HTTPMessage* /*message*/, - const std::vector>& /*body*/, - proxygen::ResponseHandler* downstream, - std::shared_ptr handlerState) { - folly::via( - httpSrvCpuExecutor_, - [this, taskId, remoteId, downstream]() { - taskManager_.removeRemoteSource(taskId, remoteId); - }) - .via( - folly::getKeepAliveToken( - folly::EventBaseManager::get()->getEventBase())) - .thenValue([downstream, handlerState](auto&& /* unused */) { - if (!handlerState->requestExpired()) { - http::sendOkResponse(downstream); - } - }) - .thenError( - folly::tag_t{}, - [downstream, handlerState](const std::exception& e) { - if (!handlerState->requestExpired()) { - http::sendErrorResponse(downstream, e.what()); - } - }); - }); + return executeAndRespond(httpSrvCpuExecutor_, [this, taskId, remoteId]() { + taskManager_.removeRemoteSource(taskId, remoteId); + }); } } // namespace facebook::presto