diff --git a/.circleci/continue_config.yml b/.circleci/continue_config.yml index 15b7c63321dc2..5a4f220e04cd2 100644 --- a/.circleci/continue_config.yml +++ b/.circleci/continue_config.yml @@ -79,6 +79,14 @@ jobs: command: | cd presto-native-execution make velox-submodule + - run: + name: "Install JWT adapter dependency" + command: | + mkdir -p ${HOME}/adapter-deps/install + source /opt/rh/gcc-toolset-9/enable + set -xu + cd presto-native-execution + DEPENDENCY_DIR=${HOME}/adapter-deps PROMPT_ALWAYS_RESPOND=n ./scripts/setup-adapters.sh jwt - run: name: "Calculate merge-base date for CCache" command: git show -s --format=%cd --date="format:%Y%m%d" $(git merge-base origin/master HEAD) | tee merge-base-date @@ -102,6 +110,7 @@ jobs: -DCMAKE_BUILD_TYPE=Debug \ -DPRESTO_ENABLE_PARQUET=ON \ -DPRESTO_ENABLE_REMOTE_FUNCTIONS=ON \ + -DPRESTO_ENABLE_JWT=ON \ -DCMAKE_PREFIX_PATH=/usr/local \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ninja -C _build/debug -j 8 @@ -210,13 +219,14 @@ jobs: cd presto-native-execution make velox-submodule - run: - name: "Install S3 adapter dependencies" + name: "Install all adapter dependencies" command: | mkdir -p ${HOME}/adapter-deps/install source /opt/rh/gcc-toolset-9/enable set -xu cd presto-native-execution DEPENDENCY_DIR=${HOME}/adapter-deps PROMPT_ALWAYS_RESPOND=n ./velox/scripts/setup-adapters.sh + DEPENDENCY_DIR=${HOME}/adapter-deps PROMPT_ALWAYS_RESPOND=n ./scripts/setup-adapters.sh jwt - run: name: "Build All" command: | @@ -232,6 +242,7 @@ jobs: -DPRESTO_ENABLE_PARQUET=ON \ -DPRESTO_ENABLE_S3=ON \ -DPRESTO_ENABLE_REMOTE_FUNCTIONS=ON \ + -DPRESTO_ENABLE_JWT=ON \ -DCMAKE_PREFIX_PATH=/usr/local \ -DCMAKE_CXX_COMPILER_LAUNCHER=ccache ninja -C _build/release -j 8 diff --git a/presto-docs/src/main/sphinx/develop/presto-native.rst b/presto-docs/src/main/sphinx/develop/presto-native.rst index fcef643649fa8..dcb2acb47d589 100644 --- a/presto-docs/src/main/sphinx/develop/presto-native.rst +++ b/presto-docs/src/main/sphinx/develop/presto-native.rst @@ -21,7 +21,7 @@ HTTP endpoints related to tasks are registered to Proxygen in * POST: v1/task: This processes a `TaskUpdateRequest` * GET: v1/task: This returns a serialized `TaskInfo` (used for comprehensive - metrics, may be reported less frequently) + metrics, may be reported less frequently) * GET: v1/task/status: This returns a serialized `TaskStatus` (used for query progress tracking, must be reported frequently) @@ -104,5 +104,26 @@ The following properties allow the configuration of remote function execution: The UDS (unix domain socket) path to communicate with a local remote function server. If specified, takes precedence over - ``remote-function-server.thrift.address`` and + ``remote-function-server.thrift.address`` and ``remote-function-server.thrift.port``. + +JWT authentication support +-------------------------- + +Prestissimo supports JWT authentication for internal communication. +For details on the generally supported parameters visit `JWT <../security/internal-communication.html#jwt>`_. + +There is also an additional parameter: + +``internal-communication.jwt.expiration-seconds`` +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + + * **Type** ``integer`` + * **Default value:** ``300`` + + There is a time period between creating the JWT on the client + and verification by the server. + If the time period is less than or equal to the parameter value, the request + is valid. + If the time period exceeds the parameter value, the request is rejected as + authentication failure (HTTP 401). \ No newline at end of file diff --git a/presto-native-execution/CMakeLists.txt b/presto-native-execution/CMakeLists.txt index 59e46cc235617..6232f7323abc7 100644 --- a/presto-native-execution/CMakeLists.txt +++ b/presto-native-execution/CMakeLists.txt @@ -58,6 +58,8 @@ option(PRESTO_ENABLE_REMOTE_FUNCTIONS "Enable remote function support" OFF) option(PRESTO_ENABLE_TESTING "Enable tests" ON) +option(PRESTO_ENABLE_JWT "Enable JWT (JSON Web Token) authentication" OFF) + # Set all Velox options below add_compile_definitions(FOLLY_HAVE_INT128_T=1) @@ -201,4 +203,8 @@ else() add_definitions(-DVELOX_DISABLE_GOOGLETEST) endif() +if(PRESTO_ENABLE_JWT) + add_compile_definitions(PRESTO_ENABLE_JWT) +endif() + add_subdirectory(presto_cpp) diff --git a/presto-native-execution/Makefile b/presto-native-execution/Makefile index 4fc5f7bf806bc..6e13f534923de 100644 --- a/presto-native-execution/Makefile +++ b/presto-native-execution/Makefile @@ -25,6 +25,7 @@ PRESTO_ENABLE_PARQUET ?= "OFF" PRESTO_ENABLE_S3 ?= "OFF" PRESTO_ENABLE_HDFS ?= "OFF" PRESTO_ENABLE_REMOTE_FUNCTIONS ?= "OFF" +PRESTO_ENABLE_JWT ?= "OFF" EXTRA_CMAKE_FLAGS ?= "" CMAKE_FLAGS := -DTREAT_WARNINGS_AS_ERRORS=${TREAT_WARNINGS_AS_ERRORS} @@ -35,6 +36,7 @@ CMAKE_FLAGS += -DPRESTO_ENABLE_PARQUET=$(PRESTO_ENABLE_PARQUET) CMAKE_FLAGS += -DPRESTO_ENABLE_S3=$(PRESTO_ENABLE_S3) CMAKE_FLAGS += -DPRESTO_ENABLE_HDFS=$(PRESTO_ENABLE_HDFS) CMAKE_FLAGS += -DPRESTO_ENABLE_REMOTE_FUNCTIONS=$(PRESTO_ENABLE_REMOTE_FUNCTIONS) +CMAKE_FLAGS += -DPRESTO_ENABLE_JWT=$(PRESTO_ENABLE_JWT) SHELL := /bin/bash diff --git a/presto-native-execution/README.md b/presto-native-execution/README.md index 19355b7b69098..9bf359d5e6593 100644 --- a/presto-native-execution/README.md +++ b/presto-native-execution/README.md @@ -55,6 +55,15 @@ This dependency can be installed by running the script below from the `./velox/scripts/setup-adapters.sh aws` +To enable JWT authentication support, set `PRESTO_ENABLE_JWT = "ON"` in +the environment. + +JWT authentication support needs the [JWT CPP](https://github.com/Thalhammer/jwt-cpp) library. +This dependency can be installed by running the script below from the +`presto/presto-native-execution` directory. + +`./scripts/setup-adapters.sh jwt` + * After installing the above dependencies, from the `presto/presto-native-execution` directory, run `make` * For development, use diff --git a/presto-native-execution/presto_cpp/main/PrestoServer.cpp b/presto-native-execution/presto_cpp/main/PrestoServer.cpp index 64b9f787ca46a..e57a1bae7f2a5 100644 --- a/presto-native-execution/presto_cpp/main/PrestoServer.cpp +++ b/presto-native-execution/presto_cpp/main/PrestoServer.cpp @@ -30,6 +30,7 @@ #include "presto_cpp/main/http/HttpServer.h" #include "presto_cpp/main/http/filters/AccessLogFilter.h" #include "presto_cpp/main/http/filters/HttpEndpointLatencyFilter.h" +#include "presto_cpp/main/http/filters/InternalAuthenticationFilter.h" #include "presto_cpp/main/http/filters/StatsFilter.h" #include "presto_cpp/main/operators/BroadcastExchangeSource.h" #include "presto_cpp/main/operators/BroadcastWrite.h" @@ -163,6 +164,15 @@ void PrestoServer::run() { clientCertAndKeyPath = optionalClientCertPath.value(); } + if (systemConfig->internalCommunicationJwtEnabled()) { +#ifndef PRESTO_ENABLE_JWT + VELOX_USER_FAIL("Internal JWT is enabled but not supported"); +#endif + VELOX_USER_CHECK( + !(systemConfig->internalCommunicationSharedSecret().empty()), + "Internal JWT is enabled without a corresponding shared secret"); + } + nodeVersion_ = systemConfig->prestoVersion(); httpExecThreads = systemConfig->httpExecThreads(); environment_ = nodeConfig->nodeEnvironment(); @@ -658,6 +668,11 @@ PrestoServer::getHttpServerFilters() { httpServer_.get())); } + // Always add the authentication filter to make sure the worker configuration + // is in line with the overall cluster configuration e.g. cannot have a worker + // without JWT enabled. + filters.push_back( + std::make_unique()); return filters; } diff --git a/presto-native-execution/presto_cpp/main/common/Configs.cpp b/presto-native-execution/presto_cpp/main/common/Configs.cpp index 930a0faeeacaa..e2c4aeed7a985 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.cpp +++ b/presto-native-execution/presto_cpp/main/common/Configs.cpp @@ -279,6 +279,9 @@ SystemConfig::SystemConfig() { NUM_PROP(kTaskRunTimeSliceMicros, 50'000), BOOL_PROP(kIncludeNodeInSpillPath, false), NUM_PROP(kOldTaskCleanUpMs, 60'000), + STR_PROP(kInternalCommunicationJwtEnabled, "false"), + STR_PROP(kInternalCommunicationSharedSecret, ""), + NUM_PROP(kInternalCommunicationJwtExpirationSeconds, 300), }; } @@ -558,6 +561,21 @@ int32_t SystemConfig::oldTaskCleanUpMs() const { return optionalProperty(kOldTaskCleanUpMs).value(); } +// The next three toggles govern the use of JWT for authentication +// for communication between the cluster nodes. +bool SystemConfig::internalCommunicationJwtEnabled() const { + return optionalProperty(kInternalCommunicationJwtEnabled).value(); +} + +std::string SystemConfig::internalCommunicationSharedSecret() const { + return optionalProperty(kInternalCommunicationSharedSecret).value(); +} + +int32_t SystemConfig::internalCommunicationJwtExpirationSeconds() const { + return optionalProperty(kInternalCommunicationJwtExpirationSeconds) + .value(); +} + NodeConfig::NodeConfig() { registeredProps_ = std::unordered_map>{ diff --git a/presto-native-execution/presto_cpp/main/common/Configs.h b/presto-native-execution/presto_cpp/main/common/Configs.h index 2f989a45c48b6..588d598672ed1 100644 --- a/presto-native-execution/presto_cpp/main/common/Configs.h +++ b/presto-native-execution/presto_cpp/main/common/Configs.h @@ -320,6 +320,14 @@ class SystemConfig : public ConfigBase { static constexpr std::string_view kRemoteFunctionServerCatalogName{ "remote-function-server.catalog-name"}; + /// Options to configure the internal (in-cluster) JWT authentication. + static constexpr std::string_view kInternalCommunicationJwtEnabled{ + "internal-communication.jwt.enabled"}; + static constexpr std::string_view kInternalCommunicationSharedSecret{ + "internal-communication.shared-secret"}; + static constexpr std::string_view kInternalCommunicationJwtExpirationSeconds{ + "internal-communication.jwt.expiration-seconds"}; + SystemConfig(); static SystemConfig* instance(); @@ -459,6 +467,12 @@ class SystemConfig : public ConfigBase { bool includeNodeInSpillPath() const; int32_t oldTaskCleanUpMs() const; + + bool internalCommunicationJwtEnabled() const; + + std::string internalCommunicationSharedSecret() const; + + int32_t internalCommunicationJwtExpirationSeconds() const; }; /// Provides access to node properties defined in node.properties file. diff --git a/presto-native-execution/presto_cpp/main/http/CMakeLists.txt b/presto-native-execution/presto_cpp/main/http/CMakeLists.txt index 565fabfa23203..614c02c7e3a39 100644 --- a/presto-native-execution/presto_cpp/main/http/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/http/CMakeLists.txt @@ -11,6 +11,12 @@ # limitations under the License. add_library(presto_http HttpClient.cpp HttpServer.cpp) +if(PRESTO_ENABLE_JWT) + add_compile_definitions(JWT_DISABLE_PICOJSON) + target_include_directories( + presto_http PRIVATE ${CMAKE_SOURCE_DIR}/presto_cpp/external/json) +endif() + add_subdirectory(filters) target_link_libraries( diff --git a/presto-native-execution/presto_cpp/main/http/HttpClient.cpp b/presto-native-execution/presto_cpp/main/http/HttpClient.cpp index cde6769887117..4b0a7fd71d286 100644 --- a/presto-native-execution/presto_cpp/main/http/HttpClient.cpp +++ b/presto-native-execution/presto_cpp/main/http/HttpClient.cpp @@ -11,8 +11,12 @@ * See the License for the specific language governing permissions and * limitations under the License. */ +#ifdef PRESTO_ENABLE_JWT +#include // @manual +#include // @manual +#include //@manual +#endif // PRESTO_ENABLE_JWT #include - #include "presto_cpp/main/common/Configs.h" #include "presto_cpp/main/http/HttpClient.h" @@ -346,4 +350,32 @@ folly::SemiFuture> HttpClient::sendRequest( return future; } +void RequestBuilder::addJwtIfConfigured() { +#ifdef PRESTO_ENABLE_JWT + if (SystemConfig::instance()->internalCommunicationJwtEnabled()) { + // If JWT was enabled the secret cannot be empty. + auto secretHash = std::vector(SHA256_DIGEST_LENGTH); + folly::ssl::OpenSSLHash::sha256( + folly::range(secretHash), + folly::ByteRange(folly::StringPiece( + SystemConfig::instance()->internalCommunicationSharedSecret()))); + + const auto time = std::chrono::system_clock::now(); + const auto token = + jwt::create() + .set_subject(NodeConfig::instance()->nodeId()) + .set_issued_at(time) + .set_expires_at( + time + + std::chrono::seconds{ + SystemConfig::instance() + ->internalCommunicationJwtExpirationSeconds()}) + .sign(jwt::algorithm::hs256{std::string( + reinterpret_cast(secretHash.data()), + secretHash.size())}); + header(kPrestoInternalBearer, token); + } +#endif // PRESTO_ENABLE_JWT +} + } // namespace facebook::presto::http diff --git a/presto-native-execution/presto_cpp/main/http/HttpClient.h b/presto-native-execution/presto_cpp/main/http/HttpClient.h index 68f07739fcfba..3815608298c9f 100644 --- a/presto-native-execution/presto_cpp/main/http/HttpClient.h +++ b/presto-native-execution/presto_cpp/main/http/HttpClient.h @@ -169,12 +169,15 @@ class RequestBuilder { folly::SemiFuture> send(HttpClient* client, const std::string& body = "", int64_t delayMs = 0) { + addJwtIfConfigured(); header(proxygen::HTTP_HEADER_CONTENT_LENGTH, std::to_string(body.size())); headers_.ensureHostHeader(); return client->sendRequest(headers_, body, delayMs); } private: + void addJwtIfConfigured(); + proxygen::HTTPMessage headers_; }; diff --git a/presto-native-execution/presto_cpp/main/http/HttpConstants.h b/presto-native-execution/presto_cpp/main/http/HttpConstants.h index 33bd0d92b82db..55c1d41266c9c 100644 --- a/presto-native-execution/presto_cpp/main/http/HttpConstants.h +++ b/presto-native-execution/presto_cpp/main/http/HttpConstants.h @@ -17,9 +17,11 @@ namespace facebook::presto::http { const uint16_t kHttpOk = 200; const uint16_t kHttpAccepted = 202; const uint16_t kHttpNoContent = 204; +const uint16_t kHttpUnauthorized = 401; const uint16_t kHttpNotFound = 404; const uint16_t kHttpInternalServerError = 500; const char kMimeTypeApplicationJson[] = "application/json"; const char kMimeTypeApplicationThrift[] = "application/x-thrift+binary"; +static const char kPrestoInternalBearer[] = "X-Presto-Internal-Bearer"; } // namespace facebook::presto::http diff --git a/presto-native-execution/presto_cpp/main/http/filters/CMakeLists.txt b/presto-native-execution/presto_cpp/main/http/filters/CMakeLists.txt index 55e9e50fd8db3..d0f97d0321c47 100644 --- a/presto-native-execution/presto_cpp/main/http/filters/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/http/filters/CMakeLists.txt @@ -11,7 +11,12 @@ # limitations under the License. add_library(http_filters AccessLogFilter.cpp HttpEndpointLatencyFilter.cpp - StatsFilter.cpp) + InternalAuthenticationFilter.cpp StatsFilter.cpp) + +if(PRESTO_ENABLE_JWT) + target_include_directories( + http_filters PRIVATE ${CMAKE_SOURCE_DIR}/presto_cpp/external/json) +endif() target_link_libraries(http_filters presto_common ${PROXYGEN_LIBRARIES}) diff --git a/presto-native-execution/presto_cpp/main/http/filters/InternalAuthenticationFilter.cpp b/presto-native-execution/presto_cpp/main/http/filters/InternalAuthenticationFilter.cpp new file mode 100644 index 0000000000000..860bcd8d78976 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/http/filters/InternalAuthenticationFilter.cpp @@ -0,0 +1,165 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#include "presto_cpp/main/http/filters/InternalAuthenticationFilter.h" +#ifdef PRESTO_ENABLE_JWT +#include //@manual +#include //@manual +#include //@manual +#endif // PRESTO_ENABLE_JWT +#include +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/main/http/HttpConstants.h" + +namespace facebook::presto::http::filters { + +/// The filter is always enabled by the presto server. +/// Therefore, it is part of processing every request. +/// This enables detecting a misconfiguration if +/// the JWT is present in the request but cannot be +/// processed due to a missing configuration. +/// This filter lets requests pass through the stages of processing +/// provided the JWT in the request has been validated. +/// If the request is rejected on the first stage (onRequest), +/// upstream is notified of the error and the request chain +/// for the subsequent stages ends processing in this filter. +void InternalAuthenticationFilter::onRequest( + std::unique_ptr msg) noexcept { + auto token = msg->getHeaders().getSingleOrEmpty(kPrestoInternalBearer); + if (!token.empty() && + !SystemConfig::instance()->internalCommunicationJwtEnabled()) { + /// Error - catch if cluster uses token but this server is not configured + /// properly. + sendUnauthorizedResponse(); + return; + } + + if (token.empty() && + SystemConfig::instance()->internalCommunicationJwtEnabled()) { + // Error - require a token to be in the request. + sendUnauthorizedResponse(); + return; + } + + if (token.empty()) { + // Forward the request. + Filter::onRequest(std::move(msg)); + } else { + processAndVerifyJwt(token, std::move(msg)); + } +} + +void InternalAuthenticationFilter::onBody( + std::unique_ptr body) noexcept { + // Passthrough if request was not rejected. + if (upstream_) { + // forward + upstream_->onBody(std::move(body)); + } +} + +void InternalAuthenticationFilter::onEOM() noexcept { + // Passthrough if request was not rejected. + if (upstream_) { + // forward + upstream_->onEOM(); + } +} + +void InternalAuthenticationFilter::requestComplete() noexcept { + // Passthrough if request was not rejected. + if (upstream_) { + upstream_->requestComplete(); + } + delete this; +} + +void InternalAuthenticationFilter::onUpgrade( + proxygen::UpgradeProtocol protocol) noexcept { + // Passthrough if request was not rejected. + if (upstream_) { + upstream_->onUpgrade(protocol); + } +} + +void InternalAuthenticationFilter::onError( + proxygen::ProxygenError err) noexcept { + // If onError is invoked before we forward the error. + if (upstream_) { + upstream_->onError(err); + upstream_ = nullptr; + } + delete this; +} + +void InternalAuthenticationFilter::sendGenericErrorResponse(void) { + /// Indicate to upstream an error occurred and make sure + /// no further forwarding occurs. + upstream_->onError(proxygen::kErrorUnsupportedExpectation); + upstream_ = nullptr; + + proxygen::ResponseBuilder(downstream_) + .status(kHttpInternalServerError, "Internal Server Error") + .sendWithEOM(); +} + +void InternalAuthenticationFilter::sendUnauthorizedResponse() { + /// Indicate to upstream an error occurred and make sure + /// no further processing occurs. + upstream_->onError(proxygen::kErrorUnauthorized); + upstream_ = nullptr; + + proxygen::ResponseBuilder(downstream_) + .status(kHttpUnauthorized, "Unauthorized") + .sendWithEOM(); +} + +void InternalAuthenticationFilter::processAndVerifyJwt( + const std::string& token, + std::unique_ptr msg) { +#ifdef PRESTO_ENABLE_JWT + try { + // Build the signature key from the secret and validate the signature. + auto secretHash = std::vector(SHA256_DIGEST_LENGTH); + folly::ssl::OpenSSLHash::sha256( + folly::range(secretHash), + folly::ByteRange(folly::StringPiece( + SystemConfig::instance()->internalCommunicationSharedSecret()))); + + // Decode and verify the JWT. + auto decodedJwt = jwt::decode(token); + auto verifier = jwt::verify().allow_algorithm( + jwt::algorithm::hs256{std::string( + reinterpret_cast(secretHash.data()), secretHash.size())}); + verifier.verify(decodedJwt); + + auto subject = decodedJwt.get_subject(); + // The nodeId of the requester is the subject. Check if it was set. + if (subject.empty()) { + std::error_code ec{jwt::error::token_verification_error::missing_claim}; + throw jwt::error::token_verification_exception(ec); + } + // Passed the verification, move the message along. + Filter::onRequest(std::move(msg)); + } catch (const jwt::error::token_verification_exception& e) { + sendUnauthorizedResponse(); + } catch (const jwt::error::signature_verification_exception& e) { + sendUnauthorizedResponse(); + } catch (const std::system_error& e) { + sendGenericErrorResponse(); + } +#endif // PRESTO_ENABLE_JWT +} + +} // namespace facebook::presto::http::filters diff --git a/presto-native-execution/presto_cpp/main/http/filters/InternalAuthenticationFilter.h b/presto-native-execution/presto_cpp/main/http/filters/InternalAuthenticationFilter.h new file mode 100644 index 0000000000000..ca4e6de7d9845 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/http/filters/InternalAuthenticationFilter.h @@ -0,0 +1,77 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +#pragma once + +#include +#include + +namespace facebook::presto::http::filters { + +class InternalAuthenticationFilter : public proxygen::Filter { + public: + explicit InternalAuthenticationFilter(proxygen::RequestHandler* upstream) + : Filter(upstream), requestRejected_(false) {} + + /// For details on the filter request handling see Filters.h and + /// RequestHandler.h. + + /// This filter is called to process each stage of the request + /// processing. + /// For each stage it has to be decided if the request is passed on or + /// if it has been rejected. The request can be rejected at the + /// "onRequest" stage in which case upstream is notified of an error + /// occurring. Subsequently, the remaining stages for this filter are called. + /// However, if the request was already rejected nothing is passed through. + void onRequest(std::unique_ptr msg) noexcept override; + + void requestComplete() noexcept override; + + void onBody(std::unique_ptr body) noexcept override; + + void onEOM() noexcept override; + + void onUpgrade(proxygen::UpgradeProtocol protocol) noexcept override; + + void onError(proxygen::ProxygenError err) noexcept override; + + private: + void sendGenericErrorResponse(void); + + void sendUnauthorizedResponse(void); + + void processAndVerifyJwt( + const std::string& token, + std::unique_ptr msg); + + bool requestRejected_; +}; + +class InternalAuthenticationFilterFactory + : public proxygen::RequestHandlerFactory { + public: + explicit InternalAuthenticationFilterFactory() {} + + void onServerStart(folly::EventBase* /*evb*/) noexcept override {} + + void onServerStop() noexcept override {} + + proxygen::RequestHandler* onRequest( + proxygen::RequestHandler* handler, + proxygen::HTTPMessage*) noexcept override { + return new InternalAuthenticationFilter(handler); + } +}; + +} // namespace facebook::presto::http::filters diff --git a/presto-native-execution/presto_cpp/main/http/tests/CMakeLists.txt b/presto-native-execution/presto_cpp/main/http/tests/CMakeLists.txt index 9733399acdcc5..ef561f3f2b01c 100644 --- a/presto-native-execution/presto_cpp/main/http/tests/CMakeLists.txt +++ b/presto-native-execution/presto_cpp/main/http/tests/CMakeLists.txt @@ -17,3 +17,14 @@ add_test( WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) target_link_libraries(presto_http_test presto_http gtest gtest_main) + +if(PRESTO_ENABLE_JWT) + add_executable(presto_http_jwt_test HttpJwtTest.cpp) + + add_test( + NAME presto_http_jwt_test + COMMAND presto_http_jwt_test + WORKING_DIRECTORY ${CMAKE_CURRENT_SOURCE_DIR}) + + target_link_libraries(presto_http_jwt_test presto_http gtest gtest_main) +endif() diff --git a/presto-native-execution/presto_cpp/main/http/tests/HttpJwtTest.cpp b/presto-native-execution/presto_cpp/main/http/tests/HttpJwtTest.cpp new file mode 100644 index 0000000000000..d3664f1cc53cf --- /dev/null +++ b/presto-native-execution/presto_cpp/main/http/tests/HttpJwtTest.cpp @@ -0,0 +1,192 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include "presto_cpp/main/http/filters/InternalAuthenticationFilter.h" +#include "presto_cpp/main/http/tests/HttpTestBase.h" + +namespace fs = boost::filesystem; + +using namespace facebook::presto; +using namespace facebook::velox; +using namespace facebook::velox::memory; + +int main(int argc, char** argv) { + testing::InitGoogleTest(&argc, argv); + folly::init(&argc, &argv, true); + return RUN_ALL_TESTS(); +} + +class HttpJwtTestSuite : public ::testing::TestWithParam { + public: + explicit HttpJwtTestSuite() { + setupJwtNodeConfig(); + } + + protected: + std::unique_ptr jwtSystemConfig( + const std::unordered_map configOverride = {}) + const { + std::unordered_map systemConfig{ + {std::string(SystemConfig::kMutableConfig), std::string("true")}, + {std::string(SystemConfig::kInternalCommunicationJwtEnabled), + std::string("true")}, + {std::string(SystemConfig::kInternalCommunicationSharedSecret), + "mysecret"}}; + + // Update the default config map with the supplied configOverride map + for (const auto& [configName, configValue] : configOverride) { + systemConfig[configName] = configValue; + } + + return std::make_unique(systemConfig); + } + + void setupJwtNodeConfig() const { + std::unordered_map nodeConfigValues{ + {std::string(NodeConfig::kMutableConfig), std::string("true")}, + {std::string(NodeConfig::kNodeId), std::string("testnode")}}; + std::unique_ptr rawNodeConfig = + std::make_unique(nodeConfigValues); + NodeConfig::instance()->initialize(std::move(rawNodeConfig)); + } + + std::unique_ptr produceHttpResponse( + const bool useHttps, + std::unordered_map clientSystemConfigOverride = + {}, + std::unordered_map serverSystemConfigOverride = + {}, + const uint64_t sendDelayMs = 500) { + auto memoryPool = defaultMemoryManager().addLeafPool("HttpJwtTestSuite"); + auto clientConfig = jwtSystemConfig(clientSystemConfigOverride); + auto systemConfig = SystemConfig::instance(); + systemConfig->initialize(std::move(clientConfig)); + + auto server = getServer(useHttps); + + auto request = std::make_shared(); + + // Set the delay to 1s to have enough time to update the config. + server->registerGet("/async/msg", asyncMsg(request)); + + std::vector> filters; + filters.push_back( + std::make_unique()); + HttpServerWrapper wrapper(std::move(server)); + wrapper.setFilters(filters); + auto serverAddress = wrapper.start().get(); + + HttpClientFactory clientFactory; + // Make sure client doesn't timeout for a delayed request. + auto client = clientFactory.newClient( + serverAddress, std::chrono::milliseconds(1'000), useHttps, memoryPool); + + auto [reqPromise, reqFuture] = folly::makePromiseContract(); + request->requestPromise = std::move(reqPromise); + + auto responseFuture = + sendGet(client.get(), "/async/msg", sendDelayMs, "TestBody"); + + auto serverConfig = jwtSystemConfig(serverSystemConfigOverride); + auto valuesMap = serverConfig->valuesCopy(); + /// The request is delayed. Meanwhile update the config so the server + /// might use a different configuration, if provided. + for (auto key : valuesMap) { + systemConfig->setValue(key.first, key.second); + } + + /// Wait and wake up after 1s, if needed, because the server should have + /// rejected the request or processed the request already. + std::move(reqFuture).wait(std::chrono::milliseconds(1'000)); + + if (auto msgPromise = request->msgPromise.lock()) { + msgPromise->promise.setValue("Success"); + } + + auto response = std::move(responseFuture).get(); + wrapper.stop(); + + return response; + } +}; + +TEST_P(HttpJwtTestSuite, basicJwtTest) { + const bool useHttps = GetParam(); + + auto response = std::move(produceHttpResponse(useHttps)); + + EXPECT_EQ(response->headers()->getStatusCode(), http::kHttpOk); +} + +TEST_P(HttpJwtTestSuite, jwtSecretMismatch) { + std::unordered_map serverConfigOverride{ + {std::string(SystemConfig::kInternalCommunicationSharedSecret), + "falseSecret"}}; + + const bool useHttps = GetParam(); + + auto response = + std::move(produceHttpResponse(useHttps, {}, serverConfigOverride)); + + EXPECT_EQ(response->headers()->getStatusCode(), http::kHttpUnauthorized); +} + +/// This test tests when a token arrives, is processed, and is rejected +/// if past the expiration time. +TEST_P(HttpJwtTestSuite, jwtExpiredToken) { + const uint64_t kSendDelay{1'500}; + + std::unordered_map clientConfigOverride{ + {std::string(SystemConfig::kInternalCommunicationJwtExpirationSeconds), + std::string("1")}}; // expire after 1s, delay is 1.5s. + + const bool useHttps = GetParam(); + + auto response = std::move( + produceHttpResponse(useHttps, clientConfigOverride, {}, kSendDelay)); + + EXPECT_EQ(response->headers()->getStatusCode(), http::kHttpUnauthorized); +} + +// Test catching a misconfiguration when a request with a JWT is received. +TEST_P(HttpJwtTestSuite, jwtServerVerificationDisabled) { + std::unordered_map serverConfigOverride{ + {std::string(SystemConfig::kInternalCommunicationJwtEnabled), + std::string("false")}}; + + const bool useHttps = GetParam(); + + auto response = + std::move(produceHttpResponse(useHttps, {}, serverConfigOverride)); + + EXPECT_EQ(response->headers()->getStatusCode(), http::kHttpUnauthorized); +} + +// Test missing client JWT. +TEST_P(HttpJwtTestSuite, jwtClientMissingJwt) { + std::unordered_map clientConfigOverride{ + {std::string(SystemConfig::kInternalCommunicationJwtEnabled), + std::string("false")}}; + + const bool useHttps = GetParam(); + + auto response = + std::move(produceHttpResponse(useHttps, clientConfigOverride)); + + EXPECT_EQ(response->headers()->getStatusCode(), http::kHttpUnauthorized); +} + +INSTANTIATE_TEST_CASE_P( + HTTPJwtTest, + HttpJwtTestSuite, + ::testing::Values(true, false)); diff --git a/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp b/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp index 70713967edc73..0d3fcb9cb207f 100644 --- a/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp +++ b/presto-native-execution/presto_cpp/main/http/tests/HttpTest.cpp @@ -11,22 +11,7 @@ * See the License for the specific language governing permissions and * limitations under the License. */ -#include -#include -#include -#include -#include -#include -#include -#include "presto_cpp/main/http/HttpClient.h" -#include "presto_cpp/main/http/HttpServer.h" -#include "velox/common/base/StatsReporter.h" - -namespace fs = boost::filesystem; - -using namespace facebook::presto; -using namespace facebook::velox; -using namespace facebook::velox::memory; +#include "presto_cpp/main/http/tests/HttpTestBase.h" int main(int argc, char** argv) { testing::InitGoogleTest(&argc, argv); @@ -34,211 +19,6 @@ int main(int argc, char** argv) { return RUN_ALL_TESTS(); } -namespace { - -std::string getCertsPath(const std::string& fileName) { - std::string currentPath = fs::current_path().c_str(); - if (boost::algorithm::ends_with(currentPath, "fbcode")) { - return currentPath + - "/github/presto-trunk/presto-native-execution/presto_cpp/main/http/tests/certs/" + - fileName; - } - - // CLion runs the tests from cmake-build-release/ or cmake-build-debug/ - // directory. Hard-coded json files are not copied there and test fails with - // file not found. Fixing the path so that we can trigger these tests from - // CLion. - boost::algorithm::replace_all(currentPath, "cmake-build-release/", ""); - boost::algorithm::replace_all(currentPath, "cmake-build-debug/", ""); - - return currentPath + "/certs/" + fileName; -} - -class HttpServerWrapper { - public: - explicit HttpServerWrapper(std::unique_ptr server) - : server_(std::move(server)) {} - - ~HttpServerWrapper() { - stop(); - } - - folly::SemiFuture start() { - auto [promise, future] = folly::makePromiseContract(); - promise_ = std::move(promise); - serverThread_ = std::make_unique([this]() { - server_->start({}, [&](proxygen::HTTPServer* httpServer) { - ASSERT_EQ(httpServer->addresses().size(), 1); - promise_.setValue(httpServer->addresses()[0].address); - }); - }); - - return std::move(future); - } - - void stop() { - if (serverThread_) { - server_->stop(); - serverThread_->join(); - serverThread_.reset(); - } - } - - private: - std::unique_ptr server_; - std::unique_ptr serverThread_; - folly::Promise promise_; -}; - -// Async SSL connection callback which auto close the socket on success for -// test. -class AsyncSSLSockAutoCloseCallback - : public folly::AsyncSocket::ConnectCallback { - public: - explicit AsyncSSLSockAutoCloseCallback(folly::AsyncSSLSocket* sock) - : sock_(sock) {} - - void connectSuccess() noexcept override { - succeeded_ = true; - sock_->close(); - } - - void connectErr(const folly::AsyncSocketException&) noexcept override { - succeeded_ = false; - } - - bool succeeded() const { - return succeeded_; - } - - private: - folly::AsyncSSLSocket* const sock_{nullptr}; - bool succeeded_{false}; -}; - -void ping( - proxygen::HTTPMessage* /*message*/, - std::vector>& /*body*/, - proxygen::ResponseHandler* downstream) { - proxygen::ResponseBuilder(downstream).status(http::kHttpOk, "").sendWithEOM(); -} - -void blackhole( - proxygen::HTTPMessage* /*message*/, - std::vector>& /*body*/, - proxygen::ResponseHandler* downstream) {} - -std::string bodyAsString(http::HttpResponse& response, MemoryPool* pool) { - EXPECT_FALSE(response.hasError()); - std::ostringstream oss; - auto iobufs = response.consumeBody(); - for (auto& body : iobufs) { - oss << std::string((const char*)body->data(), body->length()); - pool->free(body->writableData(), body->capacity()); - } - EXPECT_EQ(pool->currentBytes(), 0); - return oss.str(); -} - -std::string toString(std::vector>& bufs) { - std::ostringstream oss; - for (auto& buf : bufs) { - oss << std::string((const char*)buf->data(), buf->length()); - } - return oss.str(); -} - -void echo( - proxygen::HTTPMessage* message, - std::vector>& body, - proxygen::ResponseHandler* downstream) { - if (body.empty()) { - proxygen::ResponseBuilder(downstream) - .status(http::kHttpOk, "") - .body(folly::IOBuf::wrapBuffer( - message->getURL().c_str(), message->getURL().size())) - .sendWithEOM(); - return; - } - - proxygen::ResponseBuilder(downstream) - .status(http::kHttpOk, "") - .header(proxygen::HTTP_HEADER_CONTENT_TYPE, "text/plain") - .body(toString(body)) - .sendWithEOM(); -} - -class HttpClientFactory { - public: - HttpClientFactory() : eventBase_(std::make_unique()) { - eventBaseThread_ = - std::make_unique([&]() { eventBase_->loopForever(); }); - } - - ~HttpClientFactory() { - eventBase_->terminateLoopSoon(); - eventBaseThread_->join(); - } - - std::shared_ptr newClient( - const folly::SocketAddress& address, - const std::chrono::milliseconds& timeout, - bool useHttps, - std::shared_ptr pool, - std::function&& reportOnBodyStatsFunc = nullptr) { - if (useHttps) { - std::string clientCaPath = getCertsPath("client_ca.pem"); - std::string ciphers = "AES128-SHA,AES128-SHA256,AES256-GCM-SHA384"; - return std::make_shared( - eventBase_.get(), - address, - timeout, - pool, - clientCaPath, - ciphers, - std::move(reportOnBodyStatsFunc)); - } else { - return std::make_shared( - eventBase_.get(), - address, - timeout, - pool, - "", - "", - std::move(reportOnBodyStatsFunc)); - } - } - - private: - std::unique_ptr eventBase_; - std::unique_ptr eventBaseThread_; -}; - -folly::SemiFuture> sendGet( - http::HttpClient* client, - const std::string& url) { - return http::RequestBuilder() - .method(proxygen::HTTPMethod::GET) - .url(url) - .send(client); -} - -static std::unique_ptr getServer(bool useHttps) { - if (useHttps) { - std::string certPath = getCertsPath("test_cert1.pem"); - std::string keyPath = getCertsPath("test_key1.pem"); - std::string ciphers = "AES128-SHA,AES128-SHA256,AES256-GCM-SHA384"; - auto httpsConfig = std::make_unique( - folly::SocketAddress("127.0.0.1", 0), certPath, keyPath, ciphers); - return std::make_unique(nullptr, std::move(httpsConfig)); - } else { - return std::make_unique( - std::make_unique( - folly::SocketAddress("127.0.0.1", 0))); - } -} -} // namespace - class HttpsBasicTest : public ::testing::Test {}; TEST_F(HttpsBasicTest, ssl) { @@ -387,83 +167,6 @@ TEST_P(HttpTestSuite, serverRestart) { wrapper->stop(); } -namespace { -struct StringPromise { - explicit StringPromise(folly::Promise p) - : promise(std::move(p)) {} - folly::Promise promise; -}; - -enum RequestStatus { kStatusUnknown, kStatusInvalid, kStatusValid }; - -struct AsyncMsgRequestState { - folly::Promise requestPromise; - uint64_t maxWaitMillis{0}; - std::weak_ptr msgPromise; - RequestStatus requestStatus{kStatusUnknown}; -}; - -http::EndpointRequestHandlerFactory asyncMsg( - std::shared_ptr request) { - return [request]( - proxygen::HTTPMessage* /* message */, - const std::vector& /* args */) { - return new http::CallbackRequestHandler( - [request]( - proxygen::HTTPMessage* /*message*/, - const std::vector>& /*body*/, - proxygen::ResponseHandler* downstream, - std::shared_ptr handlerState) { - auto [promise, future] = folly::makePromiseContract(); - auto eventBase = folly::EventBaseManager::get()->getEventBase(); - auto maxWaitMillis = request->maxWaitMillis; - if (maxWaitMillis == 0) { - maxWaitMillis = 1'000'000'000; - } - - std::move(future) - .via(eventBase) - .onTimeout( - std::chrono::milliseconds(maxWaitMillis), - []() { return std::string("Timedout"); }) - .thenValue([downstream, handlerState, request](std::string msg) { - if (!handlerState->requestExpired()) { - request->requestStatus = kStatusValid; - proxygen::ResponseBuilder(downstream) - .status(http::kHttpOk, "") - .header(proxygen::HTTP_HEADER_CONTENT_TYPE, "text/plain") - .body(msg) - .sendWithEOM(); - } else { - request->requestStatus = kStatusInvalid; - } - }) - .thenError( - folly::tag_t{}, - [downstream, handlerState, request](std::exception const& e) { - if (!handlerState->requestExpired()) { - request->requestStatus = kStatusValid; - proxygen::ResponseBuilder(downstream) - .status(http::kHttpInternalServerError, "") - .header( - proxygen::HTTP_HEADER_CONTENT_TYPE, "text/plain") - .body(e.what()) - .sendWithEOM(); - } else { - request->requestStatus = kStatusInvalid; - } - }); - auto promiseHolder = - std::make_shared(std::move(promise)); - handlerState->runOnFinalization( - [promiseHolder]() mutable { promiseHolder.reset(); }); - request->msgPromise = folly::to_weak_ptr(promiseHolder); - request->requestPromise.setValue(true); - }); - }; -} -} // namespace - TEST_P(HttpTestSuite, asyncRequests) { auto memoryPool = defaultMemoryManager().addLeafPool("asyncRequests"); diff --git a/presto-native-execution/presto_cpp/main/http/tests/HttpTestBase.h b/presto-native-execution/presto_cpp/main/http/tests/HttpTestBase.h new file mode 100644 index 0000000000000..4015a5d229468 --- /dev/null +++ b/presto-native-execution/presto_cpp/main/http/tests/HttpTestBase.h @@ -0,0 +1,326 @@ +/* + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +#include +#include +#include +#include +#include +#include +#include +#include "presto_cpp/main/common/Configs.h" +#include "presto_cpp/main/http/HttpClient.h" +#include "presto_cpp/main/http/HttpServer.h" +#include "velox/common/base/StatsReporter.h" + +namespace fs = boost::filesystem; + +using namespace facebook::presto; +using namespace facebook::velox; +using namespace facebook::velox::memory; + +namespace { + +std::string getCertsPath(const std::string& fileName) { + std::string currentPath = fs::current_path().c_str(); + if (boost::algorithm::ends_with(currentPath, "fbcode")) { + return currentPath + + "/github/presto-trunk/presto-native-execution/presto_cpp/main/http/tests/certs/" + + fileName; + } + + // CLion runs the tests from cmake-build-release/ or cmake-build-debug/ + // directory. Hard-coded json files are not copied there and test fails with + // file not found. Fixing the path so that we can trigger these tests from + // CLion. + boost::algorithm::replace_all(currentPath, "cmake-build-release/", ""); + boost::algorithm::replace_all(currentPath, "cmake-build-debug/", ""); + + // As with building/testing using CLion when using a manual build and + // running the test from the build path the certs are not found and + // their path must be updated. + // The path is used by CMake. + boost::algorithm::replace_all(currentPath, "_build/debug/", ""); + boost::algorithm::replace_all(currentPath, "_build/release/", ""); + + return currentPath + "/certs/" + fileName; +} + +class HttpServerWrapper { + public: + explicit HttpServerWrapper(std::unique_ptr server) + : server_(std::move(server)) {} + + ~HttpServerWrapper() { + stop(); + } + + folly::SemiFuture start() { + auto [promise, future] = folly::makePromiseContract(); + promise_ = std::move(promise); + serverThread_ = std::make_unique([this]() { + server_->start( + std::move(filters_), [&](proxygen::HTTPServer* httpServer) { + ASSERT_EQ(httpServer->addresses().size(), 1); + promise_.setValue(httpServer->addresses()[0].address); + }); + }); + + return std::move(future); + } + + void stop() { + if (serverThread_) { + server_->stop(); + serverThread_->join(); + serverThread_.reset(); + } + } + + void setFilters( + std::vector>& filters) { + filters_ = std::move(filters); + } + + private: + std::unique_ptr server_; + std::unique_ptr serverThread_; + folly::Promise promise_; + std::vector> filters_ = {}; +}; + +// Async SSL connection callback which auto close the socket on success for +// test. +class AsyncSSLSockAutoCloseCallback + : public folly::AsyncSocket::ConnectCallback { + public: + explicit AsyncSSLSockAutoCloseCallback(folly::AsyncSSLSocket* sock) + : sock_(sock) {} + + void connectSuccess() noexcept override { + succeeded_ = true; + sock_->close(); + } + + void connectErr(const folly::AsyncSocketException&) noexcept override { + succeeded_ = false; + } + + bool succeeded() const { + return succeeded_; + } + + private: + folly::AsyncSSLSocket* const sock_{nullptr}; + bool succeeded_{false}; +}; + +void ping( + proxygen::HTTPMessage* /*message*/, + std::vector>& /*body*/, + proxygen::ResponseHandler* downstream) { + proxygen::ResponseBuilder(downstream).status(http::kHttpOk, "").sendWithEOM(); +} + +void blackhole( + proxygen::HTTPMessage* /*message*/, + std::vector>& /*body*/, + proxygen::ResponseHandler* downstream) {} + +std::string bodyAsString(http::HttpResponse& response, MemoryPool* pool) { + EXPECT_FALSE(response.hasError()); + std::ostringstream oss; + auto iobufs = response.consumeBody(); + for (auto& body : iobufs) { + oss << std::string((const char*)body->data(), body->length()); + pool->free(body->writableData(), body->capacity()); + } + EXPECT_EQ(pool->currentBytes(), 0); + return oss.str(); +} + +std::string toString(std::vector>& bufs) { + std::ostringstream oss; + for (auto& buf : bufs) { + oss << std::string((const char*)buf->data(), buf->length()); + } + return oss.str(); +} + +void echo( + proxygen::HTTPMessage* message, + std::vector>& body, + proxygen::ResponseHandler* downstream) { + if (body.empty()) { + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "") + .body(folly::IOBuf::wrapBuffer( + message->getURL().c_str(), message->getURL().size())) + .sendWithEOM(); + return; + } + + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "") + .header(proxygen::HTTP_HEADER_CONTENT_TYPE, "text/plain") + .body(toString(body)) + .sendWithEOM(); +} + +class HttpClientFactory { + public: + HttpClientFactory() : eventBase_(std::make_unique()) { + eventBaseThread_ = + std::make_unique([&]() { eventBase_->loopForever(); }); + } + + ~HttpClientFactory() { + eventBase_->terminateLoopSoon(); + eventBaseThread_->join(); + } + + std::shared_ptr newClient( + const folly::SocketAddress& address, + const std::chrono::milliseconds& timeout, + bool useHttps, + std::shared_ptr pool, + std::function&& reportOnBodyStatsFunc = nullptr) { + if (useHttps) { + std::string clientCaPath = getCertsPath("client_ca.pem"); + std::string ciphers = "AES128-SHA,AES128-SHA256,AES256-GCM-SHA384"; + return std::make_shared( + eventBase_.get(), + address, + timeout, + pool, + clientCaPath, + ciphers, + std::move(reportOnBodyStatsFunc)); + } else { + return std::make_shared( + eventBase_.get(), + address, + timeout, + pool, + "", + "", + std::move(reportOnBodyStatsFunc)); + } + } + + private: + std::unique_ptr eventBase_; + std::unique_ptr eventBaseThread_; +}; + +folly::SemiFuture> sendGet( + http::HttpClient* client, + const std::string& url, + const uint64_t sendDelay = 0, + const std::string body = "") { + return http::RequestBuilder() + .method(proxygen::HTTPMethod::GET) + .url(url) + .send(client, body, sendDelay); +} + +static std::unique_ptr getServer(bool useHttps) { + if (useHttps) { + std::string certPath = getCertsPath("test_cert1.pem"); + std::string keyPath = getCertsPath("test_key1.pem"); + std::string ciphers = "AES128-SHA,AES128-SHA256,AES256-GCM-SHA384"; + auto httpsConfig = std::make_unique( + folly::SocketAddress("127.0.0.1", 0), certPath, keyPath, ciphers); + return std::make_unique(nullptr, std::move(httpsConfig)); + } else { + return std::make_unique( + std::make_unique( + folly::SocketAddress("127.0.0.1", 0))); + } +} + +struct StringPromise { + explicit StringPromise(folly::Promise p) + : promise(std::move(p)) {} + folly::Promise promise; +}; + +enum RequestStatus { kStatusUnknown, kStatusInvalid, kStatusValid }; + +struct AsyncMsgRequestState { + folly::Promise requestPromise; + uint64_t maxWaitMillis{0}; + std::weak_ptr msgPromise; + RequestStatus requestStatus{kStatusUnknown}; +}; + +http::EndpointRequestHandlerFactory asyncMsg( + std::shared_ptr request) { + return [request]( + proxygen::HTTPMessage* /* message */, + const std::vector& /* args */) { + return new http::CallbackRequestHandler( + [request]( + proxygen::HTTPMessage* /*message*/, + const std::vector>& /*body*/, + proxygen::ResponseHandler* downstream, + std::shared_ptr handlerState) { + auto [promise, future] = folly::makePromiseContract(); + auto eventBase = folly::EventBaseManager::get()->getEventBase(); + auto maxWaitMillis = request->maxWaitMillis; + if (maxWaitMillis == 0) { + maxWaitMillis = 1'000'000'000; + } + + std::move(future) + .via(eventBase) + .onTimeout( + std::chrono::milliseconds(maxWaitMillis), + []() { return std::string("Timedout"); }) + .thenValue([downstream, handlerState, request](std::string msg) { + if (!handlerState->requestExpired()) { + request->requestStatus = kStatusValid; + proxygen::ResponseBuilder(downstream) + .status(http::kHttpOk, "") + .header(proxygen::HTTP_HEADER_CONTENT_TYPE, "text/plain") + .body(msg) + .sendWithEOM(); + } else { + request->requestStatus = kStatusInvalid; + } + }) + .thenError( + folly::tag_t{}, + [downstream, handlerState, request](std::exception const& e) { + if (!handlerState->requestExpired()) { + request->requestStatus = kStatusValid; + proxygen::ResponseBuilder(downstream) + .status(http::kHttpInternalServerError, "") + .header( + proxygen::HTTP_HEADER_CONTENT_TYPE, "text/plain") + .body(e.what()) + .sendWithEOM(); + } else { + request->requestStatus = kStatusInvalid; + } + }); + auto promiseHolder = + std::make_shared(std::move(promise)); + handlerState->runOnFinalization( + [promiseHolder]() mutable { promiseHolder.reset(); }); + request->msgPromise = folly::to_weak_ptr(promiseHolder); + request->requestPromise.setValue(true); + }); + }; +} +} // namespace diff --git a/presto-native-execution/scripts/setup-adapters.sh b/presto-native-execution/scripts/setup-adapters.sh new file mode 100755 index 0000000000000..e40080dd589d7 --- /dev/null +++ b/presto-native-execution/scripts/setup-adapters.sh @@ -0,0 +1,60 @@ +#!/bin/bash +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +# Propagate errors and improve debugging. +set -eufx -o pipefail + +SCRIPT_DIR=$(readlink -f "$(dirname "${BASH_SOURCE[0]}")") +if [ -f "${SCRIPT_DIR}/setup-helper-functions.sh" ] +then + source "${SCRIPT_DIR}/setup-helper-functions.sh" +else + source "${SCRIPT_DIR}/../velox/scripts/setup-helper-functions.sh" +fi +DEPENDENCY_DIR=${DEPENDENCY_DIR:-$(pwd)} + +function install_jwt_cpp { + github_checkout Thalhammer/jwt-cpp v0.6.0 --depth 1 + cmake_install -DBUILD_TESTS=OFF -DJWT_BUILD_EXAMPLES=OFF -DJWT_DISABLE_PICOJSON=ON -DJWT_CMAKE_FILES_INSTALL_DIR="${DEPENDENCY_DIR}/jwt-cpp" +} + +cd "${DEPENDENCY_DIR}" || exit + +install_jwt=0 + +if [ "$#" -eq 0 ]; then + # Install all adapters by default + install_jwt=1 +fi + +while [[ $# -gt 0 ]]; do + case $1 in + jwt) + install_jwt=1 + shift # past argument + ;; + *) + echo "ERROR: Unknown option $1! will be ignored!" + shift + ;; + esac +done + +if [ $install_jwt -eq 1 ]; then + install_jwt_cpp +fi + +_ret=$? +if [ $_ret -eq 0 ] ; then + echo "All deps for Presto adapters installed!" +fi