diff --git a/dbms/src/Flash/BatchCommandsHandler.cpp b/dbms/src/Flash/BatchCommandsHandler.cpp index e6768f03c13..c56a8dafad7 100644 --- a/dbms/src/Flash/BatchCommandsHandler.cpp +++ b/dbms/src/Flash/BatchCommandsHandler.cpp @@ -1,6 +1,5 @@ #include #include -#include namespace DB { @@ -10,6 +9,33 @@ BatchCommandsHandler::BatchCommandsHandler(BatchCommandsContext & batch_commands : batch_commands_context(batch_commands_context_), request(request_), response(response_), log(&Logger::get("BatchCommandsHandler")) {} +ThreadPool::Job BatchCommandsHandler::handleCommandJob( + const tikvpb::BatchCommandsRequest::Request & req, tikvpb::BatchCommandsResponse::Response & resp, grpc::Status & ret) const +{ + return [&]() { + if (!req.has_coprocessor()) + { + ret = grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); + return; + } + + const auto & cop_req = req.coprocessor(); + auto cop_resp = resp.mutable_coprocessor(); + + auto [context, status] = batch_commands_context.db_context_creation_func(&batch_commands_context.grpc_server_context); + if (!status.ok()) + { + ret = status; + return; + } + + CoprocessorContext cop_context(context, cop_req.context(), batch_commands_context.grpc_server_context); + CoprocessorHandler cop_handler(cop_context, &cop_req, cop_resp); + + ret = cop_handler.execute(); + }; +} + grpc::Status BatchCommandsHandler::execute() { if (request.requests_size() == 0) @@ -17,31 +43,6 @@ grpc::Status BatchCommandsHandler::execute() // TODO: Fill transport_layer_load into BatchCommandsResponse. - auto command_handler_func - = [](BatchCommandsContext::DBContextCreationFunc db_context_creation_func, grpc::ServerContext * grpc_server_context, - const tikvpb::BatchCommandsRequest::Request & req, tikvpb::BatchCommandsResponse::Response & resp, grpc::Status & ret) { - if (!req.has_coprocessor()) - { - ret = grpc::Status(::grpc::StatusCode::UNIMPLEMENTED, ""); - return; - } - - const auto & cop_req = req.coprocessor(); - auto cop_resp = resp.mutable_coprocessor(); - - auto [context, status] = db_context_creation_func(grpc_server_context); - if (!status.ok()) - { - ret = status; - return; - } - - CoprocessorContext cop_context(context, cop_req.context(), *grpc_server_context); - CoprocessorHandler cop_handler(cop_context, &cop_req, cop_resp); - - ret = cop_handler.execute(); - }; - /// Shortcut for only one request by not going to thread pool. if (request.requests_size() == 1) { @@ -51,7 +52,7 @@ grpc::Status BatchCommandsHandler::execute() auto resp = response.add_responses(); response.add_request_ids(request.request_ids(0)); auto ret = grpc::Status::OK; - command_handler_func(batch_commands_context.db_context_creation_func, &batch_commands_context.grpc_server_context, req, *resp, ret); + handleCommandJob(req, *resp, ret)(); return ret; } @@ -65,7 +66,7 @@ grpc::Status BatchCommandsHandler::execute() ThreadPool thread_pool(max_threads); - std::vector rets; + std::vector rets(request.requests_size()); size_t i = 0; for (const auto & req : request.requests()) @@ -73,10 +74,8 @@ grpc::Status BatchCommandsHandler::execute() auto resp = response.add_responses(); response.add_request_ids(request.request_ids(i++)); rets.emplace_back(grpc::Status::OK); - thread_pool.schedule([&]() { - command_handler_func( - batch_commands_context.db_context_creation_func, &batch_commands_context.grpc_server_context, req, *resp, rets.back()); - }); + + thread_pool.schedule(handleCommandJob(req, *resp, rets.back())); } thread_pool.wait(); @@ -85,7 +84,10 @@ grpc::Status BatchCommandsHandler::execute() for (const auto & ret : rets) { if (!ret.ok()) + { + response.Clear(); return ret; + } } return grpc::Status::OK; diff --git a/dbms/src/Flash/BatchCommandsHandler.h b/dbms/src/Flash/BatchCommandsHandler.h index 800318be39b..55b07a628fd 100644 --- a/dbms/src/Flash/BatchCommandsHandler.h +++ b/dbms/src/Flash/BatchCommandsHandler.h @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #pragma GCC diagnostic push @@ -18,10 +19,10 @@ struct BatchCommandsContext /// Context creation function for each individual command - they should be handled isolated, /// given that context is being used to pass arguments regarding queries. - using DBContextCreationFunc = std::function(grpc::ServerContext *)>; + using DBContextCreationFunc = std::function(const grpc::ServerContext *)>; DBContextCreationFunc db_context_creation_func; - grpc::ServerContext & grpc_server_context; + const grpc::ServerContext & grpc_server_context; BatchCommandsContext( Context & db_context_, DBContextCreationFunc && db_context_creation_func_, grpc::ServerContext & grpc_server_context_) @@ -40,7 +41,11 @@ class BatchCommandsHandler grpc::Status execute(); protected: - BatchCommandsContext & batch_commands_context; + ThreadPool::Job handleCommandJob( + const tikvpb::BatchCommandsRequest::Request & req, tikvpb::BatchCommandsResponse::Response & resp, grpc::Status & ret) const; + +protected: + const BatchCommandsContext & batch_commands_context; const tikvpb::BatchCommandsRequest & request; tikvpb::BatchCommandsResponse & response; diff --git a/dbms/src/Flash/CoprocessorHandler.cpp b/dbms/src/Flash/CoprocessorHandler.cpp index bed9a27624e..faeef0d11af 100644 --- a/dbms/src/Flash/CoprocessorHandler.cpp +++ b/dbms/src/Flash/CoprocessorHandler.cpp @@ -46,8 +46,8 @@ try cop_context.kv_context.region_epoch().version(), cop_context.kv_context.region_epoch().conf_ver(), std::move(key_ranges), dag_response); driver.execute(); - LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handle DAG request done"); cop_response->set_data(dag_response.SerializeAsString()); + LOG_DEBUG(log, __PRETTY_FUNCTION__ << ": Handle DAG request done"); break; } case COP_REQ_TYPE_ANALYZE: diff --git a/dbms/src/Flash/CoprocessorHandler.h b/dbms/src/Flash/CoprocessorHandler.h index 477daeeb636..900d9d77fbe 100644 --- a/dbms/src/Flash/CoprocessorHandler.h +++ b/dbms/src/Flash/CoprocessorHandler.h @@ -17,9 +17,9 @@ struct CoprocessorContext { Context & db_context; const kvrpcpb::Context & kv_context; - grpc::ServerContext & grpc_server_context; + const grpc::ServerContext & grpc_server_context; - CoprocessorContext(Context & db_context_, const kvrpcpb::Context & kv_context_, grpc::ServerContext & grpc_server_context_) + CoprocessorContext(Context & db_context_, const kvrpcpb::Context & kv_context_, const grpc::ServerContext & grpc_server_context_) : db_context(db_context_), kv_context(kv_context_), grpc_server_context(grpc_server_context_) {} }; diff --git a/dbms/src/Flash/FlashService.cpp b/dbms/src/Flash/FlashService.cpp index e1f1cb76094..5a404b0d5f2 100644 --- a/dbms/src/Flash/FlashService.cpp +++ b/dbms/src/Flash/FlashService.cpp @@ -53,7 +53,8 @@ grpc::Status FlashService::BatchCommands( tikvpb::BatchCommandsResponse response; BatchCommandsContext batch_commands_context( - context, [this](grpc::ServerContext * grpc_server_context) { return createDBContext(grpc_server_context); }, *grpc_context); + context, [this](const grpc::ServerContext * grpc_server_context) { return createDBContext(grpc_server_context); }, + *grpc_context); BatchCommandsHandler batch_commands_handler(batch_commands_context, request, response); auto ret = batch_commands_handler.execute(); if (!ret.ok()) @@ -75,22 +76,20 @@ grpc::Status FlashService::BatchCommands( return grpc::Status::OK; } -String getClientMetaVarWithDefault(grpc::ServerContext * grpc_context, const String & name, const String & default_val) +String getClientMetaVarWithDefault(const grpc::ServerContext * grpc_context, const String & name, const String & default_val) { - if (grpc_context->client_metadata().count(name) != 1) - return default_val; - else - return String(grpc_context->client_metadata().find(name)->second.data()); + if (auto it = grpc_context->client_metadata().find(name); it != grpc_context->client_metadata().end()) + return it->second.data(); + return default_val; } -std::tuple FlashService::createDBContext(grpc::ServerContext * grpc_context) +std::tuple FlashService::createDBContext(const grpc::ServerContext * grpc_context) const { /// Create DB context. Context context = server.context(); context.setGlobalContext(server.context()); /// Set a bunch of client information. - auto client_meta = grpc_context->client_metadata(); String query_id = getClientMetaVarWithDefault(grpc_context, "query_id", ""); context.setCurrentQueryId(query_id); ClientInfo & client_info = context.getClientInfo(); diff --git a/dbms/src/Flash/FlashService.h b/dbms/src/Flash/FlashService.h index 15f33df8558..09e1640ab23 100644 --- a/dbms/src/Flash/FlashService.h +++ b/dbms/src/Flash/FlashService.h @@ -25,7 +25,7 @@ class FlashService final : public tikvpb::Tikv::Service, public std::enable_shar grpc::ServerReaderWriter * stream) override; private: - std::tuple createDBContext(grpc::ServerContext * grpc_contex); + std::tuple createDBContext(const grpc::ServerContext * grpc_contex) const; private: IServer & server;