diff --git a/CMakeLists.txt b/CMakeLists.txt index 4ee98897b6f4..121dd463961d 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -17,6 +17,7 @@ if (LEGACY_BUILD) cmake_policy(SET CMP0077 OLD) # CMP0077: option() honors normal variables. Introduced in 3.13 endif () + get_filename_component(AWS_NATIVE_SDK_ROOT "${CMAKE_CURRENT_SOURCE_DIR}" ABSOLUTE) # Cmake invocation variables: @@ -33,6 +34,7 @@ if (LEGACY_BUILD) option(BUILD_SHARED_LIBS "If enabled, all aws sdk libraries will be build as shared objects; otherwise all Aws libraries will be built as static objects" ON) option(FORCE_SHARED_CRT "If enabled, will unconditionally link the standard libraries in dynamically, otherwise the standard library will be linked in based on the BUILD_SHARED_LIBS setting" ON) option(SIMPLE_INSTALL "If enabled, removes all the additional indirection (platform/cpu/config) in the bin and lib directories on the install step" ON) + option(USE_CRT_HTTP_CLIENT "If enabled, The common runtime HTTP client will be used, and the legacy systems such as WinHttp and libcurl will not be built or included" OFF) option(NO_HTTP_CLIENT "If enabled, no platform-default http client will be included in the library. For the library to be used you will need to provide your own platform-specific implementation" OFF) option(NO_ENCRYPTION "If enabled, no platform-default encryption will be included in the library. For the library to be used you will need to provide your own platform-specific implementations" OFF) option(USE_IXML_HTTP_REQUEST_2 "If enabled on windows, the com object IXmlHttpRequest2 will be used for the http stack" OFF) diff --git a/cmake/external_dependencies.cmake b/cmake/external_dependencies.cmake index 20feb6de51bf..737ecf16c293 100644 --- a/cmake/external_dependencies.cmake +++ b/cmake/external_dependencies.cmake @@ -55,7 +55,7 @@ elseif(ENABLE_INJECTED_ENCRYPTION) endif() # Http client control -if(NOT NO_HTTP_CLIENT) +if(NOT NO_HTTP_CLIENT AND NOT USE_CRT_HTTP_CLIENT) if(PLATFORM_WINDOWS) if(FORCE_CURL) set(ENABLE_CURL_CLIENT 1) @@ -114,6 +114,8 @@ if(NOT NO_HTTP_CLIENT) else() message(FATAL_ERROR "No http client available for target platform and client injection not enabled (-DNO_HTTP_CLIENT=ON)") endif() +elseif(USE_CRT_HTTP_CLIENT) + add_definitions("-DAWS_SDK_USE_CRT_HTTP -DHAVE_H2_CLIENT") else() message(STATUS "You will need to inject an http client implementation before making any http requests!") endif() diff --git a/src/aws-cpp-sdk-core/CMakeLists.txt b/src/aws-cpp-sdk-core/CMakeLists.txt index ab20cec84836..c2a4466c830d 100644 --- a/src/aws-cpp-sdk-core/CMakeLists.txt +++ b/src/aws-cpp-sdk-core/CMakeLists.txt @@ -197,8 +197,12 @@ elseif(ENABLE_WINDOWS_CLIENT) unset(CMAKE_REQUIRED_LIBRARIES) endif() +elseif(USE_CRT_HTTP_CLIENT) + file(GLOB CRT_HTTP_HEADERS "include/aws/core/http/crt/*.h") + file(GLOB CRT_HTTP_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/source/http/crt/*.cpp") endif() + if (PLATFORM_WINDOWS) file(GLOB NET_SOURCE "${CMAKE_CURRENT_SOURCE_DIR}/source/net/windows/*.cpp") elseif(PLATFORM_LINUX OR PLATFORM_APPLE OR PLATFORM_ANDROID) @@ -229,6 +233,7 @@ file(GLOB AWS_NATIVE_SDK_COMMON_SRC ${AWS_CLIENT_SOURCE} ${HTTP_STANDARD_SOURCE} ${HTTP_CLIENT_SOURCE} + ${CRT_HTTP_SOURCE} ${CONFIG_SOURCE} ${CONFIG_DEFAULTS_SOURCE} ${ENDPOINT_SOURCE} @@ -401,6 +406,8 @@ if(MSVC) elseif(ENABLE_WINDOWS_CLIENT) source_group("Header Files\\aws\\core\\http\\windows" FILES ${HTTP_WINDOWS_CLIENT_HEADERS}) endif() + source_group("Header Files\\aws\\core\\http\\crt" FILES ${CRT_HTTP_HEADERS}) + # encryption conditional headers if(ENABLE_BCRYPT_ENCRYPTION) @@ -448,8 +455,11 @@ if(MSVC) source_group("Source Files\\http\\curl" FILES ${HTTP_CURL_CLIENT_SOURCE}) elseif(ENABLE_WINDOWS_CLIENT) source_group("Source Files\\http\\windows" FILES ${HTTP_WINDOWS_CLIENT_SOURCE}) + elseif(USE_CRT_HTTP_CLIENT) + source_group("Source Files\\http\\crt" FILES ${CRT_HTTP_SOURCE}) endif() + # encryption conditional source if(ENABLE_BCRYPT_ENCRYPTION) source_group("Source Files\\utils\\crypto\\bcrypt" FILES ${UTILS_CRYPTO_BCRYPT_SOURCE}) @@ -627,8 +637,11 @@ if(ENABLE_CURL_CLIENT) install (FILES ${HTTP_CURL_CLIENT_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/aws/core/http/curl) elseif(ENABLE_WINDOWS_CLIENT) install (FILES ${HTTP_WINDOWS_CLIENT_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/aws/core/http/windows) +elseif(USE_CRT_HTTP_CLIENT) + install (FILES ${CRT_HTTP_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/aws/core/http/crt) endif() + # encryption headers if(ENABLE_BCRYPT_ENCRYPTION) install (FILES ${UTILS_CRYPTO_BCRYPT_HEADERS} DESTINATION ${INCLUDE_DIRECTORY}/aws/core/utils/crypto/bcrypt) diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h index 4c292064d73d..cb6e928e768e 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpClient.h @@ -67,10 +67,17 @@ namespace Aws bool ContinueRequest(const Aws::Http::HttpRequest&) const; + explicit operator bool() const + { + return !m_bad; + } + + protected: + bool m_bad; + private: std::atomic< bool > m_disableRequestProcessing; - std::mutex m_requestProcessingSignalLock; std::condition_variable m_requestProcessingSignal; }; diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h index 129bd3bd3643..33e5d072bdce 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpRequest.h @@ -552,7 +552,7 @@ namespace Aws { m_requestHash = std::make_pair(algorithmName, hash); } - const std::pair>& GetRequestHash() { return m_requestHash; } + const std::pair>& GetRequestHash() const { return m_requestHash; } void AddResponseValidationHash(const Aws::String& algorithmName, const std::shared_ptr& hash) { diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/HttpResponse.h b/src/aws-cpp-sdk-core/include/aws/core/http/HttpResponse.h index 8082547a7175..5979660f742b 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/HttpResponse.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/HttpResponse.h @@ -195,6 +195,14 @@ namespace Aws * Adds a header to the http response object. */ virtual void AddHeader(const Aws::String&, const Aws::String&) = 0; + /** + * Add a header to the http response object, and move the value. + * The name can't be moved as it is converted to lower-case. + * + * It isn't pure virtual for backwards compatiblity reasons, but the StandardHttpResponse used by default in the SDK + * implements the move. + */ + virtual void AddHeader(const Aws::String& headerName, Aws::String&& headerValue) { AddHeader(headerName, headerValue); }; /** * Sets the content type header on the http response object. */ diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h new file mode 100644 index 000000000000..a133273743a3 --- /dev/null +++ b/src/aws-cpp-sdk-core/include/aws/core/http/crt/CRTHttpClient.h @@ -0,0 +1,75 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include + +#include +#include +#include + +#include +#include + +namespace Aws +{ + namespace Crt + { + namespace Http + { + class HttpClientConnectionManager; + class HttpClientConnectionOptions; + } + + namespace Io + { + class ClientBootstrap; + } + } + + namespace Client + { + struct ClientConfiguration; + } // namespace Client + + namespace Http + { + /** + * Common Runtime implementation of AWS SDK for C++ HttpClient interface. + */ + class AWS_CORE_API CrtHttpClient : public HttpClient { + public: + using Base = HttpClient; + + /** + * Initializes the client with relevant parameters from clientConfig. + */ + CrtHttpClient(const Aws::Client::ClientConfiguration& clientConfig, Crt::Io::ClientBootstrap& bootstrap); + ~CrtHttpClient() override; + + std::shared_ptr MakeRequest(const std::shared_ptr& request, + Aws::Utils::RateLimits::RateLimiterInterface* readLimiter, + Aws::Utils::RateLimits::RateLimiterInterface* writeLimiter) const override; + + private: + // Yeah, I know, but someone made MakeRequest() const and didn't think about the fact that + // making an HTTP request most certainly mutates state. It was me. I'm the person that did that, and + // now we're stuck with it. Thanks me. + mutable std::unordered_map> m_connectionPools; + mutable std::mutex m_connectionPoolLock; + + Crt::Io::TlsContext m_context; + Crt::Optional m_proxyOptions; + + Crt::Io::ClientBootstrap& m_bootstrap; + Client::ClientConfiguration m_configuration; + + std::shared_ptr GetWithCreateConnectionManagerForRequest(const std::shared_ptr& request, const Crt::Http::HttpClientConnectionOptions& connectionOptions) const; + Crt::Http::HttpClientConnectionOptions CreateConnectionOptionsForRequest(const std::shared_ptr& request) const; + void CheckAndInitializeProxySettings(); + + static Aws::String ResolveConnectionPoolKey(const URI& uri); + }; + } +} diff --git a/src/aws-cpp-sdk-core/include/aws/core/http/standard/StandardHttpResponse.h b/src/aws-cpp-sdk-core/include/aws/core/http/standard/StandardHttpResponse.h index 309206f1f345..a3c5a20bdc73 100644 --- a/src/aws-cpp-sdk-core/include/aws/core/http/standard/StandardHttpResponse.h +++ b/src/aws-cpp-sdk-core/include/aws/core/http/standard/StandardHttpResponse.h @@ -37,28 +37,33 @@ namespace Aws /** * Get the headers from this response */ - HeaderValueCollection GetHeaders() const; + HeaderValueCollection GetHeaders() const override; /** * Returns true if the response contains a header by headerName */ - bool HasHeader(const char* headerName) const; + bool HasHeader(const char* headerName) const override; /** * Returns the value for a header at headerName if it exists. */ - const Aws::String& GetHeader(const Aws::String&) const; + const Aws::String& GetHeader(const Aws::String&) const override; /** * Gets the response body of the response. */ - inline Aws::IOStream& GetResponseBody() const { return bodyStream.GetUnderlyingStream(); } + inline Aws::IOStream& GetResponseBody() const override { return bodyStream.GetUnderlyingStream(); } /** * Gives full control of the memory of the ResponseBody over to the caller. At this point, it is the caller's * responsibility to clean up this object. */ - inline Utils::Stream::ResponseStream&& SwapResponseStreamOwnership() { return std::move(bodyStream); } + inline Utils::Stream::ResponseStream&& SwapResponseStreamOwnership() override { return std::move(bodyStream); } /** * Adds a header to the http response object. */ - void AddHeader(const Aws::String&, const Aws::String&); + void AddHeader(const Aws::String&, const Aws::String&) override; + /** + * Add a header to the http response object, and move the value. + * The name can't be moved as it is converted to lower-case. + */ + void AddHeader(const Aws::String& headerName, Aws::String&& headerValue) override; private: StandardHttpResponse(const StandardHttpResponse&); diff --git a/src/aws-cpp-sdk-core/source/Aws.cpp b/src/aws-cpp-sdk-core/source/Aws.cpp index 4fd97618f308..45364b24ab67 100644 --- a/src/aws-cpp-sdk-core/source/Aws.cpp +++ b/src/aws-cpp-sdk-core/source/Aws.cpp @@ -16,6 +16,8 @@ #include #include +#include + namespace Aws { static const char* ALLOCATION_TAG = "Aws_Init_Cleanup"; diff --git a/src/aws-cpp-sdk-core/source/http/HttpClient.cpp b/src/aws-cpp-sdk-core/source/http/HttpClient.cpp index 854202339362..a90c5e94e4a1 100644 --- a/src/aws-cpp-sdk-core/source/http/HttpClient.cpp +++ b/src/aws-cpp-sdk-core/source/http/HttpClient.cpp @@ -10,6 +10,7 @@ using namespace Aws; using namespace Aws::Http; HttpClient::HttpClient() : + m_bad(false), m_disableRequestProcessing( false ), m_requestProcessingSignalLock(), m_requestProcessingSignal() diff --git a/src/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp b/src/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp index da1759ef9776..10616afedb05 100644 --- a/src/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp +++ b/src/aws-cpp-sdk-core/source/http/HttpClientFactory.cpp @@ -5,6 +5,10 @@ #include +#if AWS_SDK_USE_CRT_HTTP +#include +#include +#endif #if ENABLE_CURL_CLIENT #include #include @@ -62,10 +66,12 @@ namespace Aws { std::shared_ptr CreateHttpClient(const ClientConfiguration& clientConfiguration) const override { +#if AWS_SDK_USE_CRT_HTTP + return Aws::MakeShared(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, clientConfiguration, *GetDefaultClientBootstrap()); // Figure out whether the selected option is available but fail gracefully and return a default of some type if not // Windows clients: Http and Inet are always options, Curl MIGHT be an option if USE_CURL_CLIENT is on, and http is "default" // Other clients: Curl is your default -#if ENABLE_WINDOWS_CLIENT +#elif ENABLE_WINDOWS_CLIENT #if ENABLE_WINDOWS_IXML_HTTP_REQUEST_2_CLIENT #if BYPASS_DEFAULT_PROXY switch (clientConfiguration.httpLibOverride) @@ -189,7 +195,17 @@ namespace Aws std::shared_ptr CreateHttpClient(const Aws::Client::ClientConfiguration& clientConfiguration) { assert(GetHttpClientFactory()); - return GetHttpClientFactory()->CreateHttpClient(clientConfiguration); + auto client = GetHttpClientFactory()->CreateHttpClient(clientConfiguration); + + if (!client) + { + AWS_LOGSTREAM_FATAL(HTTP_CLIENT_FACTORY_ALLOCATION_TAG, "Initializing Http Client failed!"); + // assert just in case this is a misconfiguration at development time to make the dev's job easier. + assert(false && "Http client initialization failed. Some client configuration parameters are probably invalid"); + std::abort(); + } + + return client; } std::shared_ptr CreateHttpRequest(const Aws::String& uri, HttpMethod method, const Aws::IOStreamFactory& streamFactory) diff --git a/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp new file mode 100644 index 000000000000..b804d6747cbe --- /dev/null +++ b/src/aws-cpp-sdk-core/source/http/crt/CRTHttpClient.cpp @@ -0,0 +1,697 @@ +/** + * Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. + * SPDX-License-Identifier: Apache-2.0. + */ + +#include +#include +#include +#include + +#include +#include + +namespace Aws +{ + namespace Http + { + static const char *const CRT_HTTP_CLIENT_TAG = "CrtHttpClient"; + + // Adapts AWS SDK input streams and rate limiters to the Crt input stream reading model. + class SDKAdaptingInputStream : public Crt::Io::StdIOStreamInputStream + { + public: + SDKAdaptingInputStream(Utils::RateLimits::RateLimiterInterface* rateLimiter, std::shared_ptr stream, + const Http::HttpClient& client, const Http::HttpRequest& request, + Aws::Crt::Allocator *allocator = Crt::ApiAllocator()) noexcept : + Crt::Io::StdIOStreamInputStream(std::move(stream), allocator), m_rateLimiter(rateLimiter), + m_client(client), m_currentRequest(request), m_chunkEnd(false) + { + m_isAwsChunked = m_currentRequest.HasHeader(Aws::Http::CONTENT_ENCODING_HEADER) && + m_currentRequest.GetHeaderValue(Aws::Http::CONTENT_ENCODING_HEADER) == Aws::Http::AWS_CHUNKED_VALUE; + } + protected: + + bool ReadImpl(Crt::ByteBuf &buffer) noexcept override + { + if (!m_client.ContinueRequest(m_currentRequest) || !m_client.IsRequestProcessingEnabled()) + { + return false; + } + + size_t amountToRead = buffer.capacity - buffer.len; + uint8_t* originalBufferPos = buffer.buffer; + + // aws-chunk = hex(chunk-size) + CRLF + chunk-data + CRLF + // Needs to reserve bytes of sizeof(hex(chunk-size)) + sizeof(CRLF) + sizeof(CRLF) + if (m_isAwsChunked) + { + Aws::String amountToReadHexString = Aws::Utils::StringUtils::ToHexString(amountToRead); + + auto expansionSpace = amountToReadHexString.size() + 4; + if (amountToRead < expansionSpace) + { + // if we don't have enough left to read into go ahead and bail. We can handle it next time. + return true; + } + + amountToRead -= expansionSpace; + } + + // initial check to see if we should avoid reading for the moment. + if (!m_rateLimiter || m_rateLimiter->ApplyCost(0) == std::chrono::milliseconds(0)) { + size_t currentPos = buffer.len; + + // now do the read. We may over read by an IO buffer size, but it's fine. The throttle will still + // kick-in in plenty of time. + if (!Crt::Io::StdIOStreamInputStream::ReadImpl(buffer)) + { + return false; + } + + size_t newPos = buffer.len; + assert(newPos >= currentPos && "the buffer length should not have decreased in value."); + size_t amountRead = newPos - currentPos; + + if (m_isAwsChunked) + { + // if we have a chunk to wrap, wrap it, be sure to update the running checksum. + if (amountRead > 0) + { + if (m_currentRequest.GetRequestHash().second != nullptr) + { + m_currentRequest.GetRequestHash().second->Update(reinterpret_cast(originalBufferPos), amountRead); + } + + Aws::String hex = Aws::Utils::StringUtils::ToHexString(amountRead); + // this is safe because of the isAwsChunked branch above. + // I don't see a aws_byte_buf equivalent of memmove. This is lifted from the curl implementation. + memmove(originalBufferPos + hex.size() + 2, originalBufferPos, amountRead); + memmove(originalBufferPos + hex.size() + 2 + amountRead, "\r\n", 2); + memmove(originalBufferPos, hex.c_str(), hex.size()); + memmove(originalBufferPos + hex.size(), "\r\n", 2); + amountRead += hex.size() + 4; + } + else if (!m_chunkEnd) + { + auto status = GetStatusImpl(); + if (!status.is_end_of_stream) + { + // if we didn't read anything, then lets finish up the chunk and send it. + // the reference implementation seems to assume only one chunk is allowed, + // because the chunkEnd bit is never updated keep that same behavior here. + Aws::StringStream chunkedTrailer; + chunkedTrailer << "0\r\n"; + if (m_currentRequest.GetRequestHash().second != nullptr) + { + chunkedTrailer << "x-amz-checksum-" << m_currentRequest.GetRequestHash().first + << ":" + << HashingUtils::Base64Encode( + m_currentRequest.GetRequestHash().second->GetHash().GetResult()) + << "\r\n"; + } + chunkedTrailer << "\r\n"; + amountRead = chunkedTrailer.str().size(); + memcpy(originalBufferPos, chunkedTrailer.str().c_str(), amountRead); + m_chunkEnd = true; + } + } + buffer.len += amountRead; + } + + auto& sentHandler = m_currentRequest.GetDataSentEventHandler(); + if (sentHandler) + { + sentHandler(&m_currentRequest, static_cast(amountRead)); + } + + if (m_rateLimiter) + { + // now actually reduce the window. + m_rateLimiter->ApplyCost(static_cast(newPos - currentPos)); + } + } + + return true; + } + + private: + Utils::RateLimits::RateLimiterInterface* m_rateLimiter; + const Http::HttpClient& m_client; + const Http::HttpRequest& m_currentRequest; + bool m_chunkEnd; + bool m_isAwsChunked; + }; + + // Just a wrapper around a Condition Variable and a mutex, which handles wait and timed waits while protecting + // from spurious wakeups. + class AsyncWaiter + { + public: + AsyncWaiter() = default; + AsyncWaiter(const AsyncWaiter&) = delete; + AsyncWaiter& operator=(const AsyncWaiter&) = delete; + + void Wakeup() + { + { + std::lock_guard locker(m_lock); + m_wakeupIntentional = true; + } + m_cvar.notify_one(); + } + + void WaitOnCompletion() + { + std::unique_lock uniqueLocker(m_lock); + m_cvar.wait(uniqueLocker, [this](){return m_wakeupIntentional;}); + } + + bool WaitOnCompletionUntil(std::chrono::time_point until) + { + std::unique_lock uniqueLocker(m_lock); + return m_cvar.wait_until(uniqueLocker, until, [this](){return m_wakeupIntentional;}); + } + + private: + std::mutex m_lock; + std::condition_variable m_cvar; + bool m_wakeupIntentional{false}; + }; + + CrtHttpClient::CrtHttpClient(const Aws::Client::ClientConfiguration& clientConfig, Crt::Io::ClientBootstrap& bootstrap) : + HttpClient(), m_context(), m_proxyOptions(), m_bootstrap(bootstrap), m_configuration(clientConfig) + { + //first need to figure TLS out... + Crt::Io::TlsContextOptions tlsContextOptions = Crt::Io::TlsContextOptions::InitDefaultClient(); + if (!tlsContextOptions) + { + m_bad = true; + return; + } + + CheckAndInitializeProxySettings(); + // the previous function can fail to setup the proxy options correctly. + // if that happened the bad bit has been set, early exit here in that case. + if (m_bad) + { + return; + } + + // Given current SDK configuration assumptions, if the ca is overridden and a proxy is configured, + // it's intended for the proxy, not this context. + if (!m_proxyOptions.has_value()) + { + if (!m_configuration.caPath.empty() || !m_configuration.caFile.empty()) + { + const char* caPath = m_configuration.caPath.empty() ? nullptr : m_configuration.caPath.c_str(); + const char* caFile = m_configuration.caFile.empty() ? nullptr : m_configuration.caFile.c_str(); + if (!tlsContextOptions.OverrideDefaultTrustStore(caPath, caFile)) + { + m_bad = true; + return; + } + } + } + + tlsContextOptions.SetVerifyPeer(m_configuration.verifySSL); + + if (Crt::Io::TlsContextOptions::IsAlpnSupported()) + { + // this may need to be pulled from the client configuration.... + if (!tlsContextOptions.SetAlpnList("h2;http/1.1")) + { + m_bad = true; + return; + } + } + + Crt::Io::TlsContext newContext(tlsContextOptions, Crt::Io::TlsMode::CLIENT); + + if (!newContext) + { + m_bad = true; + return; + } + + m_context = std::move(newContext); + } + + // this isn't entirely necessary, but if you want to be nice to debuggers and memory checkers, let's go ahead + // and shut everything down cleanly. + CrtHttpClient::~CrtHttpClient() + { + Aws::Vector> shutdownFutures; + + for (auto& managerPair : m_connectionPools) + { + shutdownFutures.push_back(managerPair.second->InitiateShutdown()); + } + + for (auto& shutdownFuture : shutdownFutures) + { + shutdownFuture.get(); + } + + shutdownFutures.clear(); + m_connectionPools.clear(); + } + + static void AddRequestMetadataToCrtRequest(const std::shared_ptr& request, const std::shared_ptr& crtRequest) + { + const char* methodStr = Aws::Http::HttpMethodMapper::GetNameForHttpMethod(request->GetMethod()); + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Making " << methodStr << " request to " << request->GetURIString()); + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Including headers:"); + //Add http headers to the request. + for (const auto& header : request->GetHeaders()) + { + Crt::Http::HttpHeader crtHeader; + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, header.first << ": " << header.second); + crtHeader.name = Crt::ByteCursorFromArray((const uint8_t *)header.first.data(), header.first.length()); + crtHeader.value = Crt::ByteCursorFromArray((const uint8_t *)header.second.data(), header.second.length()); + crtRequest->AddHeader(crtHeader); + } + + // HTTP method, GET, PUT, DELETE, etc... + auto methodCursor = Crt::ByteCursorFromCString(methodStr); + crtRequest->SetMethod(methodCursor); + + // Path portion of the request + auto pathStrCpy = request->GetUri().GetURLEncodedPathRFC3986(); + auto queryStrCpy = request->GetUri().GetQueryString(); + Aws::StringStream ss; + + //Crt client has you pass the query string as part of the path. concatenate that here. + ss << pathStrCpy << queryStrCpy; + auto fullPathAndQueryCpy = ss.str(); + auto pathCursor = Crt::ByteCursorFromArray((uint8_t *)fullPathAndQueryCpy.c_str(), fullPathAndQueryCpy.length()); + crtRequest->SetPath(pathCursor); + } + + static void OnResponseBodyReceived(Crt::Http::HttpStream& stream, const Crt::ByteCursor& body, const std::shared_ptr& response, const std::shared_ptr& request, const Http::HttpClient& client) + { + if (!client.ContinueRequest(*request) || !client.IsRequestProcessingEnabled()) + { + AWS_LOGSTREAM_INFO(CRT_HTTP_CLIENT_TAG, "Request canceled. Canceling request by closing the connection."); + stream.GetConnection().Close(); + return; + } + + //TODO: handle the read rate limiter here, once backpressure is setup. + for (const auto& hashIterator : request->GetResponseValidationHashes()) + { + hashIterator.second->Update(reinterpret_cast(body.ptr), body.len); + } + + // When data is received from the content body of the incoming response, just copy it to the output stream. + response->GetResponseBody().write((const char*)body.ptr, static_cast(body.len)); + + if (request->IsEventStreamRequest() && !response->HasHeader(Aws::Http::X_AMZN_ERROR_TYPE)) + { + response->GetResponseBody().flush(); + } + + auto& receivedHandler = request->GetDataReceivedEventHandler(); + if (receivedHandler) + { + receivedHandler(request.get(), response.get(), static_cast(body.len)); + } + + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, body.len << " bytes written to response."); + + } + + // on response headers arriving, write them to the response. + static void OnIncomingHeaders(Crt::Http::HttpStream&, enum aws_http_header_block block, const Crt::Http::HttpHeader* headersArray, std::size_t headersCount, const std::shared_ptr& response) + { + if (block == AWS_HTTP_HEADER_BLOCK_INFORMATIONAL) return; + + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Received Headers: "); + + for (size_t i = 0; i < headersCount; ++i) + { + const Crt::Http::HttpHeader* header = &headersArray[i]; + Aws::String headerNameStr((const char* const)header->name.ptr, header->name.len); + Aws::String headerValueStr((const char* const)header->value.ptr, header->value.len); + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, headerNameStr << ": " << headerValueStr); + response->AddHeader(headerNameStr, std::move(headerValueStr)); + } + } + + static void OnIncomingHeadersBlockDone(Crt::Http::HttpStream& stream, enum aws_http_header_block, const std::shared_ptr& response) + { + AWS_LOGSTREAM_TRACE(CRT_HTTP_CLIENT_TAG, "Received response code: " << stream.GetResponseStatusCode()); + response->SetResponseCode((HttpResponseCode)stream.GetResponseStatusCode()); + } + + // Request is done. If there was an error set it, otherwise just wake up the cvar. + static void OnStreamComplete(Crt::Http::HttpStream&, int errorCode, AsyncWaiter& waiter, const std::shared_ptr& request, const std::shared_ptr& response, const HttpClient& client) + { + if (errorCode) + { + if (!client.IsRequestProcessingEnabled() || !client.ContinueRequest(*request)) + { + response->SetClientErrorType(Aws::Client::CoreErrors::USER_CANCELLED); + response->SetClientErrorMessage("Request cancelled by user"); + } + else + { + response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(aws_error_debug_str(errorCode)); + } + } + + waiter.Wakeup(); + } + + // if the connection acquisition failed, go ahead and fail the request and wakeup the cvar. + // If it succeeded go ahead and make the request. + static void OnClientConnectionAvailable(std::shared_ptr connection, int errorCode, std::shared_ptr& connectionReference, + Crt::Http::HttpRequestOptions& requestOptions, AsyncWaiter& waiter, const std::shared_ptr& request, + const std::shared_ptr& response, const HttpClient& client) + { + bool shouldContinueRequest = client.ContinueRequest(*request) && client.IsRequestProcessingEnabled(); + + if (!shouldContinueRequest) + { + response->SetClientErrorType(CoreErrors::USER_CANCELLED); + response->SetClientErrorMessage("Request cancelled by user's continuation handler"); + waiter.Wakeup(); + return; + } + + int finalErrorCode = errorCode; + if (connection) + { + AWS_LOGSTREAM_DEBUG(CRT_HTTP_CLIENT_TAG, "Obtained connection handle " << (void*)connection.get()); + + auto clientStream = connection->NewClientStream(requestOptions); + connectionReference = connection; + + if (clientStream && clientStream->Activate()) { + return; + } + + finalErrorCode = aws_last_error(); + AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Initiation of request failed because " << aws_error_debug_str(finalErrorCode)); + + } + + const char *errorMsg = aws_error_debug_str(finalErrorCode); + AWS_LOGSTREAM_ERROR(CRT_HTTP_CLIENT_TAG, "Obtaining connection failed because " << errorMsg); + response->SetClientErrorType(Aws::Client::CoreErrors::NETWORK_CONNECTION); + response->SetClientErrorMessage(errorMsg); + + waiter.Wakeup(); + } + + std::shared_ptr CrtHttpClient::MakeRequest(const std::shared_ptr& request, + Aws::Utils::RateLimits::RateLimiterInterface*, + Aws::Utils::RateLimits::RateLimiterInterface*) const + { + if (m_bad) + { + return nullptr; + } + + auto crtRequest = Crt::MakeShared(Crt::g_allocator); + auto response = Aws::MakeShared(CRT_HTTP_CLIENT_TAG, request); + + auto requestConnOptions = CreateConnectionOptionsForRequest(request); + auto connectionManager = GetWithCreateConnectionManagerForRequest(request, requestConnOptions); + + if (!connectionManager) + { + response->SetClientErrorMessage(aws_error_debug_str(aws_last_error())); + response->SetClientErrorType(CoreErrors::INVALID_PARAMETER_COMBINATION); + return response; + } + AddRequestMetadataToCrtRequest(request, crtRequest); + + // Set the request body stream on the crt request. Setup the write rate limiter if present + if (request->GetContentBody()) + { + crtRequest->SetBody(Aws::MakeShared(CRT_HTTP_CLIENT_TAG, m_configuration.writeRateLimiter.get(), request->GetContentBody(), *this, *request)); + } + + Crt::Http::HttpRequestOptions requestOptions; + requestOptions.request = crtRequest.get(); + + requestOptions.onIncomingBody = + [this, request, response](Crt::Http::HttpStream& stream, const Crt::ByteCursor& body) + { + OnResponseBodyReceived(stream, body, response, request, *this); + }; + + requestOptions.onIncomingHeaders = + [response](Crt::Http::HttpStream& stream, enum aws_http_header_block block, const Crt::Http::HttpHeader* headersArray, std::size_t headersCount) + { + OnIncomingHeaders(stream, block, headersArray, headersCount, response); + }; + + // This will arrive at or around the same time as the headers. Use it to set the response code on the response + requestOptions.onIncomingHeadersBlockDone = + [response](Crt::Http::HttpStream& stream, enum aws_http_header_block block) + { + OnIncomingHeadersBlockDone(stream, block, response); + }; + + // Crt client is async only so we'll need to do the synchronous part ourselves. + // We'll use a condition variable and wait on it until the request completes or errors out. + AsyncWaiter waiter; + + requestOptions.onStreamComplete = + [&waiter, &request, this, &response](Crt::Http::HttpStream& stream, int errorCode) + { + OnStreamComplete(stream, errorCode, waiter, request, response, *this); + }; + + std::shared_ptr connectionRef(nullptr); + + // now we finally have the request, get a connection and make the request. + connectionManager->AcquireConnection( + [&connectionRef, &requestOptions, response, &waiter, request, this] + (std::shared_ptr connection, int errorCode) + { + OnClientConnectionAvailable(connection, errorCode, connectionRef, requestOptions, waiter, request, response, *this); + }); + + bool waiterTimedOut = false; + // Naive http request timeout implementation. This doesn't factor in how long it took to get the connection from the pool, and + // I'm undecided on the queueing theory implications of this decision so if this turns out to be the wrong granularity + // this is the section of code you should be changing. You can probably get "close" by having an additional + // atomic (not necessarily full on atomics implementation, but it needs to be the size of a WORD if it's not) + // counter that gets incremented in the acquireConnection callback as long as your connection timeout + // is shorter than your request timeout. Even if it's not, that would handle like.... 4-5 nines of getting this right. + // since in the worst case scenario, your connect timeout got preempted by the request timeout, and is it really worth + // all that effort if that's the worst thing that can happen? + if (m_configuration.requestTimeoutMs > 0 ) + { + auto requestExpiryTime = std::chrono::high_resolution_clock::now() + + std::chrono::milliseconds(m_configuration.requestTimeoutMs); + waiterTimedOut = !waiter.WaitOnCompletionUntil(requestExpiryTime); + + // if this is true, the waiter timed out without a terminal condition being woken up. + if (waiterTimedOut) + { + // close the connection if it's still there so we can expedite anything we're waiting on. + if (connectionRef) + { + connectionRef->Close(); + } + } + } + + // always wait, even if the above section timed out, because Wakeup() hasn't yet been called, + // and this means we're still waiting on some queued up callbacks to fire. + // going past this point before that occurs will cause a segfault when the callback DOES finally fire + // since the waiter is on the stack. + waiter.WaitOnCompletion(); + + // now handle if we timed out or not. + // OnStreamComplete will have set an error by this point if the connection was closed out + // due to a timeout. Check that an error hasn't been set first, because if it hasn't + // the request actually succeeded. + if (waiterTimedOut && response->GetClientErrorType() != Aws::Client::CoreErrors::OK) + { + response->SetClientErrorType( + Aws::Client::CoreErrors::REQUEST_TIMEOUT); + response->SetClientErrorMessage("Request Timeout Has Expired"); + } + + // TODO: is VOX support still a thing? If so we need to add the metrics for it. + return response; + } + + Aws::String CrtHttpClient::ResolveConnectionPoolKey(const URI& uri) + { + // create a unique key for this endpoint. + Aws::StringStream ss; + ss << SchemeMapper::ToString(uri.GetScheme()) << "://" << uri.GetAuthority() << ":" << uri.GetPort(); + + return ss.str(); + } + + // The main purpose of this is to ensure there's exactly one connection manager per unique endpoint. + // To do so, we simply keep a hash table of the endpoint key (see ResolveConnectionPoolKey()), and + // put a connection manager for that endpoint as the value. + // This runs in multiple threads potentially so there's a lock around it. + std::shared_ptr CrtHttpClient::GetWithCreateConnectionManagerForRequest(const std::shared_ptr& request, const Crt::Http::HttpClientConnectionOptions& options) const + { + const auto connManagerRequestKey = ResolveConnectionPoolKey(request->GetUri()); + + std::lock_guard locker(m_connectionPoolLock); + + const auto& foundManager = m_connectionPools.find(connManagerRequestKey); + + // We've already got one, return it. + if (foundManager != m_connectionPools.cend()) { + return foundManager->second; + } + + // don't have a manager for this endpoint, so create one for it. + Crt::Http::HttpClientConnectionManagerOptions connectionManagerOptions; + connectionManagerOptions.ConnectionOptions = options; + connectionManagerOptions.MaxConnections = m_configuration.maxConnections; + connectionManagerOptions.EnableBlockingShutdown = true; + //TODO: need to bind out Monitoring options to handle the read timeout config value. + // once done, come back and use it to setup read timeouts. + + auto connectionManager = Crt::Http::HttpClientConnectionManager::NewClientConnectionManager(connectionManagerOptions); + + if (!connectionManager) + { + return nullptr; + } + + // put it in the hash table and return it. + m_connectionPools.emplace(connManagerRequestKey, connectionManager); + + return connectionManager; + } + + Crt::Http::HttpClientConnectionOptions CrtHttpClient::CreateConnectionOptionsForRequest(const std::shared_ptr& request) const + { + // connection options are unique per request, this is mostly just connection-level configuration mapping. + Crt::Http::HttpClientConnectionOptions connectionOptions; + connectionOptions.HostName = request->GetUri().GetAuthority().c_str(); + // TODO: come back and update this when we hook up the rate limiters. + connectionOptions.ManualWindowManagement = false; + connectionOptions.Port = request->GetUri().GetPort(); + + if (request->GetUri().GetScheme() == Scheme::HTTPS) + { + connectionOptions.TlsOptions = m_context.NewConnectionOptions(); + auto serverName = request->GetUri().GetAuthority(); + auto serverNameCursor = Crt::ByteCursorFromCString(serverName.c_str()); + connectionOptions.TlsOptions->SetServerName(serverNameCursor); + } + + connectionOptions.Bootstrap = &m_bootstrap; + + if (m_proxyOptions.has_value()) + { + connectionOptions.ProxyOptions = m_proxyOptions.value(); + } + + connectionOptions.SocketOptions.SetConnectTimeoutMs(m_configuration.connectTimeoutMs); + connectionOptions.SocketOptions.SetKeepAlive(m_configuration.enableTcpKeepAlive); + + if (m_configuration.enableTcpKeepAlive) + { + connectionOptions.SocketOptions.SetKeepAliveIntervalSec( + (uint16_t) (m_configuration.tcpKeepAliveIntervalMs / 1000)); + } + connectionOptions.SocketOptions.SetSocketType(Crt::Io::SocketType::Stream); + + return connectionOptions; + } + + // The proxy config is pretty hefty, so we don't want to create one for each request when we don't have to. + // This converts whatever proxy settings are in clientConfig to Crt specific proxy settings. + // It then sets it on the member variable for re-use elsewhere. + void CrtHttpClient::CheckAndInitializeProxySettings() + { + if (!m_configuration.proxyHost.empty()) + { + Crt::Http::HttpClientConnectionProxyOptions proxyOptions; + + if (!m_configuration.proxyUserName.empty()) + { + proxyOptions.AuthType = Crt::Http::AwsHttpProxyAuthenticationType::Basic; + proxyOptions.BasicAuthUsername = m_configuration.proxyUserName.c_str(); + proxyOptions.BasicAuthPassword = m_configuration.proxyPassword.c_str(); + } + + proxyOptions.HostName = m_configuration.proxyHost.c_str(); + + if (m_configuration.proxyPort != 0) + { + proxyOptions.Port = static_cast(m_configuration.proxyPort); + } + else + { + proxyOptions.Port = m_configuration.proxyScheme == Scheme::HTTPS ? 443 : 80; + } + + if (m_configuration.proxyScheme == Scheme::HTTPS) + { + Crt::Io::TlsContextOptions contextOptions = Crt::Io::TlsContextOptions::InitDefaultClient(); + + if (!contextOptions) + { + m_bad = true; + return; + } + + if (m_configuration.proxySSLKeyPassword.empty() && !m_configuration.proxySSLCertPath.empty()) + { + const char* certPath = m_configuration.proxySSLCertPath.empty() ? nullptr : m_configuration.proxySSLCertPath.c_str(); + const char* certFile = m_configuration.proxySSLKeyPath.empty() ? nullptr : m_configuration.proxySSLKeyPath.c_str(); + contextOptions = Crt::Io::TlsContextOptions::InitClientWithMtls(certPath, certFile); + if (!contextOptions) + { + m_bad = true; + return; + } + } + else if (!m_configuration.proxySSLKeyPassword.empty()) + { + const char* pkcs12CertFile = m_configuration.proxySSLKeyPath.empty() ? nullptr : m_configuration.proxySSLKeyPath.c_str(); + const char* pkcs12Pwd = m_configuration.proxySSLKeyPassword.c_str(); + contextOptions = Crt::Io::TlsContextOptions::InitClientWithMtlsPkcs12(pkcs12CertFile, pkcs12Pwd); + if (!contextOptions) + { + m_bad = true; + return; + } + } + + if (!m_configuration.caFile.empty() || !m_configuration.caPath.empty()) + { + const char* caPath = m_configuration.caPath.empty() ? nullptr : m_configuration.caPath.c_str(); + const char* caFile = m_configuration.caFile.empty() ? nullptr : m_configuration.caFile.c_str(); + contextOptions.OverrideDefaultTrustStore(caPath, caFile); + if (!contextOptions) + { + m_bad = true; + return; + } + } + + contextOptions.SetVerifyPeer(m_configuration.verifySSL); + Crt::Io::TlsContext context = Crt::Io::TlsContext(contextOptions, Crt::Io::TlsMode::CLIENT); + proxyOptions.TlsOptions = context.NewConnectionOptions(); + if (proxyOptions.TlsOptions) + { + m_bad = true; + return; + } + } + + m_proxyOptions = std::move(proxyOptions); + } + } + + } +} diff --git a/src/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp b/src/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp index 8b62ae5e634d..bf3a9eb5aeea 100644 --- a/src/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp +++ b/src/aws-cpp-sdk-core/source/http/standard/StandardHttpResponse.cpp @@ -21,9 +21,9 @@ HeaderValueCollection StandardHttpResponse::GetHeaders() const { HeaderValueCollection headerValueCollection; - for (Aws::Map::const_iterator iter = headerMap.begin(); iter != headerMap.end(); ++iter) + for (const auto & iter : headerMap) { - headerValueCollection.emplace(HeaderValuePair(iter->first, iter->second)); + headerValueCollection.emplace(HeaderValuePair(iter.first, iter.second)); } return headerValueCollection; @@ -36,11 +36,11 @@ bool StandardHttpResponse::HasHeader(const char* headerName) const const Aws::String& StandardHttpResponse::GetHeader(const Aws::String& headerName) const { - Aws::Map::const_iterator foundValue = headerMap.find(StringUtils::ToLower(headerName.c_str())); + auto foundValue = headerMap.find(StringUtils::ToLower(headerName.c_str())); assert(foundValue != headerMap.end()); if (foundValue == headerMap.end()) { AWS_LOGSTREAM_ERROR(STANDARD_HTTP_RESPONSE_LOG_TAG, "Requested a header value for a missing header key: " << headerName); - static const Aws::String EMPTY_STRING = ""; + static const Aws::String EMPTY_STRING; return EMPTY_STRING; } return foundValue->second; @@ -51,4 +51,9 @@ void StandardHttpResponse::AddHeader(const Aws::String& headerName, const Aws::S headerMap[StringUtils::ToLower(headerName.c_str())] = headerValue; } +void StandardHttpResponse::AddHeader(const Aws::String& headerName, Aws::String&& headerValue) +{ + headerMap.emplace(StringUtils::ToLower(headerName.c_str()), std::move(headerValue)); +} + diff --git a/tests/aws-cpp-sdk-sqs-integration-tests/QueueOperationTest.cpp b/tests/aws-cpp-sdk-sqs-integration-tests/QueueOperationTest.cpp index 52117b13d89c..2fa214a4ac4c 100644 --- a/tests/aws-cpp-sdk-sqs-integration-tests/QueueOperationTest.cpp +++ b/tests/aws-cpp-sdk-sqs-integration-tests/QueueOperationTest.cpp @@ -101,6 +101,7 @@ class QueueOperationTest : public ::testing::Test config.proxyHost = PROXY_HOST; config.proxyPort = PROXY_PORT; #endif + config.requestTimeoutMs = 20000; return config; }