diff --git a/.vscode/cspell.json b/.vscode/cspell.json index cd1346732a..f4dc3a9d1a 100644 --- a/.vscode/cspell.json +++ b/.vscode/cspell.json @@ -84,6 +84,7 @@ "ncus", "Niels", "nlohmann", + "nohup", "nostd", "noclean", "NOCLOSE", @@ -99,6 +100,7 @@ "pdbs", "Piotrowski", "PUCHAR", + "PVOID", "pwsh", "Ragrs", "Ragzrs", diff --git a/eng/pipelines/templates/jobs/archetype-sdk-client.yml b/eng/pipelines/templates/jobs/archetype-sdk-client.yml index 14883215a9..6a1652faec 100644 --- a/eng/pipelines/templates/jobs/archetype-sdk-client.yml +++ b/eng/pipelines/templates/jobs/archetype-sdk-client.yml @@ -32,6 +32,12 @@ parameters: - name: TestEnv type: object default: [] + - name: PreTestSteps + type: object + default: [] + - name: PostTestSteps + type: object + default: [] jobs: - template: /eng/common/pipelines/templates/jobs/archetype-sdk-tests-generate.yml @@ -53,6 +59,8 @@ jobs: LineCoverageTarget: ${{ parameters.LineCoverageTarget }} BranchCoverageTarget: ${{ parameters.BranchCoverageTarget }} TestEnv: ${{ parameters.TestEnv }} + PreTestSteps: ${{ parameters.PreTestSteps }} + PostTestSteps: ${{ parameters.PostTestSteps }} # Disable build for cpp - client - ${{ if ne(parameters.ServiceDirectory, 'not-specified' )}}: diff --git a/eng/pipelines/templates/jobs/ci.tests.yml b/eng/pipelines/templates/jobs/ci.tests.yml index 9f27a6f855..87b5eb0f82 100644 --- a/eng/pipelines/templates/jobs/ci.tests.yml +++ b/eng/pipelines/templates/jobs/ci.tests.yml @@ -46,6 +46,13 @@ parameters: - name: UsePlatformContainer type: boolean default: false + - name: PreTestSteps + type: object + default: [] + - name: PostTestSteps + type: object + default: [] + jobs: - job: @@ -137,6 +144,8 @@ jobs: BuildArgs: "$(BuildArgs)" Env: "$(CmakeEnvArg)" + - ${{ parameters.PreTestSteps }} + - pwsh: | ctest ` -C Debug ` @@ -148,6 +157,8 @@ jobs: workingDirectory: build displayName: Test + - ${{ parameters.PostTestSteps }} + - task: PublishTestResults@2 inputs: testResultsFormat: cTest diff --git a/eng/pipelines/templates/jobs/live.tests.yml b/eng/pipelines/templates/jobs/live.tests.yml index e5272766e3..19e73bb0b5 100644 --- a/eng/pipelines/templates/jobs/live.tests.yml +++ b/eng/pipelines/templates/jobs/live.tests.yml @@ -36,6 +36,12 @@ parameters: - name: UsePlatformContainer type: boolean default: false +- name: PreTestSteps + type: object + default: [] +- name: PostTestSteps + type: object + default: [] jobs: - job: ValidateLive @@ -126,6 +132,8 @@ jobs: Location: ${{ coalesce(parameters.Location, parameters.CloudConfig.Location) }} SubscriptionConfiguration: ${{ parameters.CloudConfig.SubscriptionConfiguration }} + - ${{ parameters.PreTestSteps }} + # For non multi-config generator use the same build configuration to run tests # We don't need to set it to invoke ctest # Visual Studio generator used in CI is a multi-config generator. @@ -141,6 +149,8 @@ jobs: succeeded(), ne(variables['RunSamples'], '1')) + - ${{ parameters.PostTestSteps }} + - task: PublishTestResults@2 inputs: testResultsFormat: cTest diff --git a/eng/pipelines/templates/stages/archetype-sdk-client.yml b/eng/pipelines/templates/stages/archetype-sdk-client.yml index a9bc6a575c..334e580647 100644 --- a/eng/pipelines/templates/stages/archetype-sdk-client.yml +++ b/eng/pipelines/templates/stages/archetype-sdk-client.yml @@ -67,6 +67,12 @@ parameters: - name: UnsupportedClouds type: string default: '' +- name: PreTestSteps + type: object + default: [] +- name: PostTestSteps + type: object + default: [] stages: @@ -96,6 +102,8 @@ stages: ${{ if eq(parameters.ServiceDirectory, 'template') }}: TestPipeline: true TestEnv: ${{ parameters.TestEnv }} + PreTestSteps: ${{ parameters.PreTestSteps }} + PostTestSteps: ${{ parameters.PostTestSteps }} - ${{ if and(eq(variables['System.TeamProject'], 'internal'), ne(parameters.LiveTestCtestRegex, '')) }}: - template: /eng/pipelines/templates/stages/archetype-sdk-tests.yml @@ -110,6 +118,8 @@ stages: Clouds: ${{ parameters.Clouds }} SupportedClouds: ${{ parameters.SupportedClouds }} UnsupportedClouds: ${{ parameters.UnsupportedClouds }} + PreTestSteps: ${{ parameters.PreTestSteps }} + PostTestSteps: ${{ parameters.PostTestSteps }} - ${{ if and(eq(variables['System.TeamProject'], 'internal'), not(endsWith(variables['Build.DefinitionName'], ' - tests'))) }}: - template: archetype-cpp-release.yml diff --git a/eng/pipelines/templates/stages/archetype-sdk-tests.yml b/eng/pipelines/templates/stages/archetype-sdk-tests.yml index 737de7b297..a4aba5565b 100644 --- a/eng/pipelines/templates/stages/archetype-sdk-tests.yml +++ b/eng/pipelines/templates/stages/archetype-sdk-tests.yml @@ -29,6 +29,12 @@ parameters: - name: UnsupportedClouds type: string default: '' +- name: PreTestSteps + type: object + default: [] +- name: PostTestSteps + type: object + default: [] stages: - ${{ each cloud in parameters.CloudConfig }}: @@ -57,3 +63,5 @@ stages: Coverage: ${{ parameters.Coverage}} CoverageReportPath: ${{ parameters.CoverageReportPath}} TimeoutInMinutes: ${{ parameters.TimeoutInMinutes}} + PreTestSteps: ${{ parameters.PreTestSteps }} + PostTestSteps: ${{ parameters.PostTestSteps }} diff --git a/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt b/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt index bf3077e676..6d0dc15151 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt +++ b/sdk/core/azure-core-tracing-opentelemetry/CMakeLists.txt @@ -36,7 +36,7 @@ if (BUILD_AZURE_CORE_TRACING_OPENTELEMETRY) find_package(azure-core-cpp REQUIRED) endif() endif() - find_package(opentelemetry-cpp "1.3.0" CONFIG REQUIRED) + find_package(opentelemetry-cpp CONFIG REQUIRED) set( AZURE_CORE_OPENTELEMETRY_HEADER diff --git a/sdk/core/azure-core-tracing-opentelemetry/test/ut/azure_core_otel_test.cpp b/sdk/core/azure-core-tracing-opentelemetry/test/ut/azure_core_otel_test.cpp index bba341df18..ded42c49d3 100644 --- a/sdk/core/azure-core-tracing-opentelemetry/test/ut/azure_core_otel_test.cpp +++ b/sdk/core/azure-core-tracing-opentelemetry/test/ut/azure_core_otel_test.cpp @@ -531,6 +531,9 @@ TEST_F(OpenTelemetryTests, SetStatus) span->SetStatus(Azure::Core::Tracing::_internal::SpanStatus::Error, {}); span->SetStatus(Azure::Core::Tracing::_internal::SpanStatus::Ok, {}); + EXPECT_THROW( + span->SetStatus(static_cast(357), {}), + std::runtime_error); span->End({}); @@ -553,7 +556,7 @@ TEST_F(OpenTelemetryTests, SetStatus) span->SetStatus(Azure::Core::Tracing::_internal::SpanStatus::Error, "Something went wrong."); - span->End({}); + span->End(Azure::DateTime(std::chrono::system_clock::now())); // Return the collected spans. auto spans = m_spanData->GetSpans(); diff --git a/sdk/core/azure-core/CMakeLists.txt b/sdk/core/azure-core/CMakeLists.txt index 978a59def3..6d94397403 100644 --- a/sdk/core/azure-core/CMakeLists.txt +++ b/sdk/core/azure-core/CMakeLists.txt @@ -42,14 +42,20 @@ if(BUILD_TRANSPORT_CURL) src/http/curl/curl_connection_pool_private.hpp src/http/curl/curl_connection_private.hpp src/http/curl/curl_session_private.hpp - ) + src/http/curl/curl_websockets.cpp + ) SET(CURL_TRANSPORT_ADAPTER_INC inc/azure/core/http/curl_transport.hpp + inc/azure/core/http/websockets/curl_websockets_transport.hpp ) endif() if(BUILD_TRANSPORT_WINHTTP) - SET(WIN_TRANSPORT_ADAPTER_SRC src/http/winhttp/win_http_transport.cpp) - SET(WIN_TRANSPORT_ADAPTER_INC inc/azure/core/http/win_http_transport.hpp) + SET(WIN_TRANSPORT_ADAPTER_SRC + src/http/winhttp/win_http_transport.cpp + src/http/winhttp/win_http_websockets.cpp) + SET(WIN_TRANSPORT_ADAPTER_INC + inc/azure/core/http/win_http_transport.hpp + inc/azure/core/http/websockets/win_http_websockets_transport.hpp) endif() set( @@ -74,6 +80,8 @@ set( inc/azure/core/http/policies/policy.hpp inc/azure/core/http/raw_response.hpp inc/azure/core/http/transport.hpp + inc/azure/core/http/websockets/websockets.hpp + inc/azure/core/http/websockets/websockets_transport.hpp inc/azure/core/internal/client_options.hpp inc/azure/core/internal/contract.hpp inc/azure/core/internal/cryptography/sha_hash.hpp @@ -132,6 +140,8 @@ set( src/http/transport_policy.cpp src/http/url.cpp src/http/user_agent.cpp + src/http/websockets/websockets.cpp + src/http/websockets/websockets_impl.cpp src/io/body_stream.cpp src/io/random_access_file_body_stream.cpp src/logger.cpp diff --git a/sdk/core/azure-core/inc/azure/core/base64.hpp b/sdk/core/azure-core/inc/azure/core/base64.hpp index 689736ab53..cb5daadcf5 100644 --- a/sdk/core/azure-core/inc/azure/core/base64.hpp +++ b/sdk/core/azure-core/inc/azure/core/base64.hpp @@ -18,7 +18,10 @@ namespace Azure { namespace Core { /** * @brief Used to convert one form of data into another, for example encoding binary data into - * Base64 text. + * Base64 encoded octets. + * + * @note Base64 encoded data is a subset of the ASCII encoding (characters 0-127). As such, + * it can be considered a subset of UTF-8. */ class Convert final { private: @@ -31,17 +34,17 @@ namespace Azure { namespace Core { public: /** - * @brief Encodes the vector of binary data into UTF-8 encoded text represented as Base64. + * @brief Encodes a vector of binary data using Base64. * - * @param data The input vector that contains binary data that needs to be encoded. - * @return The UTF-8 encoded text in Base64. + * @param data The input vector that contains binary data to be encoded. + * @return The Base64 encoded contents of the vector. */ static std::string Base64Encode(const std::vector& data); /** - * @brief Decodes the UTF-8 encoded text represented as Base64 into binary data. + * @brief Decodes a Base64 encoded data into a vector of binary data. * - * @param text The input UTF-8 encoded text in Base64 that needs to be decoded. + * @param text Base64 encoded data to be decoded. * @return The decoded binary data. */ static std::vector Base64Decode(const std::string& text); diff --git a/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp b/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp index 9ef494cab1..a2225574cf 100644 --- a/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp +++ b/sdk/core/azure-core/inc/azure/core/cryptography/hash.hpp @@ -120,6 +120,11 @@ namespace Azure { namespace Core { namespace Cryptography { /** * @brief Represents the class for the MD5 hash function which maps binary data of an arbitrary * length to small binary data of a fixed length. + * + * @warning MD5 is a deprecated hashing algorithm and SHOULD NOT be used, + * unless it is used to implement a specific protocol (See RFC 6151 for more information + * about the weaknesses of the MD5 hash function). Client implementers should strongly prefer the + * SHA256, SHA384, and SHA512 hash functions. */ class Md5Hash final : public Hash { diff --git a/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp index a06fac3499..398708399e 100644 --- a/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/curl_transport.hpp @@ -13,6 +13,7 @@ #include "azure/core/http/transport.hpp" namespace Azure { namespace Core { namespace Http { + class CurlNetworkConnection; namespace _detail { /** @@ -46,7 +47,7 @@ namespace Azure { namespace Core { namespace Http { /** * @brief Set the libcurl connection options like a proxy and CA path. */ - struct CurlTransportOptions final + struct CurlTransportOptions { /** * @brief The string for the proxy is passed directly to the libcurl handle without any parsing. @@ -126,10 +127,17 @@ namespace Azure { namespace Core { namespace Http { /** * @brief Concrete implementation of an HTTP Transport that uses libcurl. */ - class CurlTransport final : public HttpTransport { + class CurlTransport : public HttpTransport { private: CurlTransportOptions m_options; + protected: + /** + * @brief Called when an HTTP response indicates the connection should be upgraded to + * a websocket. Takes ownership of the CurlNetworkConnection object. + */ + virtual void OnUpgradedConnection(std::unique_ptr&&){}; + public: /** * @brief Construct a new CurlTransport object. @@ -140,6 +148,12 @@ namespace Azure { namespace Core { namespace Http { { } + // See also: + // [Core Guidelines C.35: "A base class destructor should be either public + // and virtual or protected and + // non-virtual"](http://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c35-a-base-class-destructor-should-be-either-public-and-virtual-or-protected-and-non-virtual) + virtual ~CurlTransport() = default; + /** * @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse * diff --git a/sdk/core/azure-core/inc/azure/core/http/transport.hpp b/sdk/core/azure-core/inc/azure/core/http/transport.hpp index 64d6f6da15..448f060319 100644 --- a/sdk/core/azure-core/inc/azure/core/http/transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/transport.hpp @@ -66,6 +66,12 @@ namespace Azure { namespace Core { namespace Http { * @return A reference to this instance. */ HttpTransport& operator=(const HttpTransport& other) = default; + + /** + * @brief Returns true if the HttpTransport supports WebSockets (the ability to + * communicate bidirectionally on the TCP connection used by the HTTP transport). + */ + virtual bool HasWebSocketSupport() const { return false; } }; }}} // namespace Azure::Core::Http diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/curl_websockets_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/curl_websockets_transport.hpp new file mode 100644 index 0000000000..d3b1ecb1ad --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/curl_websockets_transport.hpp @@ -0,0 +1,155 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via CURL. + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/curl_transport.hpp" +#include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/websockets_transport.hpp" +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + struct CurlWebSocketTransportOptions : public Azure::Core::Http::CurlTransportOptions + { + }; + /** + * @brief Concrete implementation of a WebSocket Transport that uses libcurl. + */ + class CurlWebSocketTransport : public CurlTransport, public WebSocketTransport { + public: + /** + * @brief Construct a new CurlWebSocketTransport object. + * + * @param options Optional parameter to override the default options. + */ + CurlWebSocketTransport( + CurlWebSocketTransportOptions const& options = CurlWebSocketTransportOptions()) + : CurlTransport(options) + { + } + + /** + * @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse + * + * @param request an HTTP Request to be send. + * @param context A context to control the request lifetime. + * + * @return unique ptr to an HTTP RawResponse. + */ + virtual std::unique_ptr Send(Request& request, Context const& context) override; + + /** + * @brief Indicates if the transport natively supports websockets or not. + * + * @details For the CURL websocket transport, the transport does NOT support native websockets - + * it is the responsibility of the client of the WebSocketTransport to format WebSocket protocol + * elements. + */ + virtual bool HasBuiltInWebSocketSupport() override { return false; } + + /** + * @brief Closes the WebSocket handle. + * + */ + virtual void Close() override; + + // Native WebSocket support methods. + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + * The first param is the close reason, the second is descriptive text. + */ + virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&) + override + { + throw std::runtime_error("Not implemented."); + } + + /** + * @brief Retrieve the status of the close socket operation. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + */ + NativeWebSocketCloseInformation NativeGetCloseSocketInformation( + const Azure::Core::Context&) override + { + throw std::runtime_error("Not implemented"); + } + + /** + * @brief Send a frame of data to the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + */ + virtual void NativeSendFrame( + NativeWebSocketFrameType, + std::vector const&, + Azure::Core::Context const&) override + { + throw std::runtime_error("Not implemented."); + } + + /** + * @brief Receive a frame of data from the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + */ + virtual NativeWebSocketReceiveInformation NativeReceiveFrame( + Azure::Core::Context const&) override + { + throw std::runtime_error("Not implemented"); + } + + // Non-Native WebSocket support. + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or until + * there is no more data to get from the socket. + * + * @param buffer Buffer to fill with data. + * @param bufferSize Size of buffer. + * @param context Context to control the request lifetime. + * + * @returns Buffer data received. + * + */ + virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) + override; + + /** + * @brief This method will use libcurl socket to write all the bytes from buffer. + * + * @param buffer Buffer to send. + * @param bufferSize Number of bytes to write. + * @param context Context for the operation. + */ + virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) + override; + + /** + * @brief returns true if this transport supports WebSockets, false otherwise. + */ + bool HasWebSocketSupport() const override { return true; } + + private: + // std::unique_ptr cannot be constructed on an incomplete type (CurlNetworkConnection), but + // std::shared_ptr can be. + std::shared_ptr m_upgradedConnection; + void OnUpgradedConnection( + std::unique_ptr&& upgradedConnection) override; + }; + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/websockets.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets.hpp new file mode 100644 index 0000000000..a0a55d3bd4 --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets.hpp @@ -0,0 +1,404 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Azure Core APIs implementing the WebSocket protocol [RFC 6455] + * (https://www.rfc-editor.org/rfc/rfc6455.html). + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/internal/client_options.hpp" +#include +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + namespace _detail { + class WebSocketImplementation; + } + namespace _internal { + + enum class WebSocketFrameType : int + { + Unknown, + TextFrameReceived, + BinaryFrameReceived, + PeerClosedReceived, + }; + + enum class WebSocketErrorCode : uint16_t + { + OK = 1000, + EndpointDisappearing = 1001, + ProtocolError = 1002, + UnknownDataType = 1003, + Reserved1 = 1004, + NoStatusCodePresent = 1005, + ConnectionClosedWithoutCloseFrame = 1006, + InvalidMessageData = 1007, + PolicyViolation = 1008, + MessageTooLarge = 1009, + ExtensionNotFound = 1010, + UnexpectedError = 1011, + TlsHandshakeFailure = 1015, + }; + + class WebSocketTextFrame; + class WebSocketBinaryFrame; + class WebSocketPeerCloseFrame; + + namespace _detail { + class WebSocketImplementation; + } + /** @brief Statistics about data sent and received by the WebSocket. + * + * @remarks This class is primarily intended for test collateral and debugging to allow + * a caller to determine information about the status of a WebSocket. + * + * Note: Some of these statistics are not available if the underlying transport supports native + * websockets. + */ + struct WebSocketStatistics + { + /** @brief The number of WebSocket frames sent on this WebSocket. */ + uint32_t FramesSent; + + /** @brief The number of bytes of data sent to the peer on this WebSocket. */ + uint32_t BytesSent; + + /** @brief The number of WebSocket frames received from the peer. */ + uint32_t FramesReceived; + + /** @brief The number of bytes received from the peer. */ + uint32_t BytesReceived; + + /** @brief The number of "Ping" frames received from the peer. */ + uint32_t PingFramesReceived; + + /** @brief The number of "Ping" frames sent to the peer. */ + uint32_t PingFramesSent; + + /** @brief The number of "Pong" frames received from the peer. */ + uint32_t PongFramesReceived; + + /** @brief The number of "Pong" frames sent to the peer. */ + uint32_t PongFramesSent; + + /** @brief The number of "Text" frames received from the peer. */ + uint32_t TextFramesReceived; + + /** @brief The number of "Text" frames sent to the peer. */ + uint32_t TextFramesSent; + + /** @brief The number of "Binary" frames received from the peer. */ + uint32_t BinaryFramesReceived; + + /** @brief The number of "Binary" frames sent to the peer. */ + uint32_t BinaryFramesSent; + + /** @brief The number of "Continuation" frames sent to the peer. */ + uint32_t ContinuationFramesSent; + + /** @brief The number of "Continuation" frames received from the peer. */ + uint32_t ContinuationFramesReceived; + + /** @brief The number of "Close" frames received from the peer. */ + uint32_t CloseFramesReceived; + + /** @brief The number of frames received which were not processed. */ + uint32_t FramesDropped; + + /** @brief The number of frames received which were not returned because they were received + * after the Close() method was called. */ + + uint32_t FramesDroppedByClose; + /** @brief The number of frames dropped because they were over the maximum payload size. */ + + uint32_t FramesDroppedByPayloadSizeLimit; + /** @brief The number of frames dropped because they were out of compliance with the protocol. + */ + uint32_t FramesDroppedByProtocolError; + + /** @brief The number of reads performed on the transport.*/ + uint32_t TransportReads; + + /** @brief The number of bytes read from the transport. */ + uint32_t TransportReadBytes; + }; + + /** @brief A frame of data received from a WebSocket. + */ + class WebSocketFrame { + public: + /** @brief The type of frame received: Text, Binary or Close. */ + WebSocketFrameType FrameType{}; + + /** @brief True if the frame received is a "final" frame */ + bool IsFinalFrame{false}; + + /** @brief Returns the contents of the frame as a Text frame. + * @returns A WebSocketTextFrame containing the contents of the frame. + */ + std::shared_ptr AsTextFrame(); + + /** @brief Returns the contents of the frame as a Binary frame. + * @returns A WebSocketBinaryFrame containing the contents of the frame. + */ + + std::shared_ptr AsBinaryFrame(); + /** @brief Returns the contents of the frame as a Peer Close frame. + * @returns A WebSocketPeerCloseFrame containing the contents of the frame. + */ + std::shared_ptr AsPeerCloseFrame(); + + /** @brief Construct a new instance of a WebSocketFrame.*/ + WebSocketFrame() = default; + + /** @brief Construct a new instance of a WebSocketFrame with a specific frame type. + * @param frameType The type of frame received. + */ + WebSocketFrame(WebSocketFrameType frameType) : FrameType{frameType} {} + + /** @brief Construct a new instance of a WebSocketFrame with a specific frame type and final + * flag. + * @param frameType The type of frame received. + * @param isFinalFrame true if the frame is the final frame. + */ + WebSocketFrame(WebSocketFrameType frameType, bool isFinalFrame) + : FrameType{frameType}, IsFinalFrame{isFinalFrame} + { + } + }; + + /** @brief Contains the contents of a WebSocket Text frame.*/ + class WebSocketTextFrame : public WebSocketFrame, + public std::enable_shared_from_this { + friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation; + + private: + public: + /** @brief Constructs a new WebSocketTextFrame */ + WebSocketTextFrame() : WebSocketFrame(WebSocketFrameType::TextFrameReceived){}; + + /** @brief Text of the frame received from the remote peer. */ + std::string Text; + + private: + /** @brief Constructs a new WebSocketTextFrame + * @param isFinalFrame True if this is the final frame in a multi-frame message. + * @param body UTF-8 encoded text of the frame data. + * @param size Length in bytes of the frame body. + */ + WebSocketTextFrame(bool isFinalFrame, uint8_t const* body, size_t size) + : WebSocketFrame{WebSocketFrameType::TextFrameReceived, isFinalFrame}, + Text(body, body + size) + { + } + }; + + /** @brief Contains the contents of a WebSocket Binary frame.*/ + class WebSocketBinaryFrame : public WebSocketFrame, + public std::enable_shared_from_this { + friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation; + + private: + public: + /** @brief Constructs a new WebSocketBinaryFrame */ + WebSocketBinaryFrame() : WebSocketFrame(WebSocketFrameType::BinaryFrameReceived){}; + + /** @brief Binary frame data received from the remote peer. */ + std::vector Data; + + /** @brief Constructs a new WebSocketBinaryFrame + * @param isFinal True if this is the final frame in a multi-frame message. + * @param body binary of the frame data. + * @param size Length in bytes of the frame body. + */ + private: + WebSocketBinaryFrame(bool isFinal, uint8_t const* body, size_t size) + : WebSocketFrame{WebSocketFrameType::BinaryFrameReceived, isFinal}, + Data(body, body + size) + { + } + }; + + /** @brief Contains the contents of a WebSocket Close frame.*/ + class WebSocketPeerCloseFrame : public WebSocketFrame, + public std::enable_shared_from_this { + friend Azure::Core::Http::WebSockets::_detail::WebSocketImplementation; + + public: + /** @brief Constructs a new WebSocketPeerCloseFrame */ + WebSocketPeerCloseFrame() : WebSocketFrame(WebSocketFrameType::PeerClosedReceived){}; + + /** @brief Status code sent from the remote peer. Typically a member of the WebSocketErrorCode + * enumeration */ + uint16_t RemoteStatusCode{}; + + /** @brief Optional text sent from the remote peer. */ + std::string RemoteCloseReason; + + private: + /** @brief Constructs a new WebSocketBinaryFrame + * @param remoteStatusCode Status code sent by the remote peer. + * @param remoteCloseReason Optional reason sent by the remote peer. + */ + WebSocketPeerCloseFrame(uint16_t remoteStatusCode, std::string const& remoteCloseReason) + : WebSocketFrame{WebSocketFrameType::PeerClosedReceived}, + RemoteStatusCode(remoteStatusCode), RemoteCloseReason(remoteCloseReason) + { + } + }; + + struct WebSocketOptions : Azure::Core::_internal::ClientOptions + { + /** + * @brief The set of protocols which are supported by this client + */ + std::vector Protocols = {}; + + /** + * @brief The protocol name of the service client. Used for the User-Agent header + * in the initial WebSocket handshake. + */ + std::string ServiceName; + + /** + * @brief The version of the service client. Used for the User-Agent header in the + * initial WebSocket handshake + */ + std::string ServiceVersion; + + /** + * @brief The period of time between ping operations, default is 60 seconds. + */ + std::chrono::duration PingInterval{std::chrono::seconds{60}}; + + /** + * @brief Construct an instance of a WebSocketOptions type. + * + * @param protocols Supported protocols for this websocket client. + */ + explicit WebSocketOptions(std::vector protocols) + : Azure::Core::_internal::ClientOptions{}, Protocols(protocols) + { + } + WebSocketOptions() = default; + }; + + class WebSocket { + public: + /** @brief Constructs a new instance of a WebSocket with the specified WebSocket options. + * + * @param remoteUrl The URL of the remote WebSocket server. + * @param options The options to use for the WebSocket. + */ + explicit WebSocket( + Azure::Core::Url const& remoteUrl, + WebSocketOptions const& options = WebSocketOptions{}); + + /** @brief Destroys an instance of a WebSocket. + */ + ~WebSocket(); + + /** @brief Opens a WebSocket connection to a remote server. + * + * @param context Context for the operation, used for cancellation and timeout. + */ + void Open(Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Closes a WebSocket connection to the remote server gracefully. + * + * @param context Context for the operation. + */ + void Close(Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Closes a WebSocket connection to the remote server with additional context. + * + * @param closeStatus 16 bit WebSocket error code. + * @param closeReason String describing the reason for closing the socket. + * @param context Context for the operation. + */ + void Close( + uint16_t closeStatus, + std::string const& closeReason = {}, + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Sends a String frame to the remote server. + * + * @param textFrame UTF-8 encoded text to send. + * @param isFinalFrame if True, this is the final frame in a multi-frame message. + * @param context Context for the operation. + */ + void SendFrame( + std::string const& textFrame, + bool isFinalFrame = false, + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Sends a Binary frame to the remote server. + * + * @param binaryFrame Binary data to send. + * @param isFinalFrame if True, this is the final frame in a multi-frame message. + * @param context Context for the operation. + */ + void SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame = false, + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief Receive a frame from the remote server. + * + * @param context Context for the operation. + * + * @returns The received WebSocket frame. + * + */ + std::shared_ptr ReceiveFrame( + Azure::Core::Context const& context = Azure::Core::Context{}); + + /** @brief AddHeader - Adds a header to the initial handshake. + * + * @note This API is ignored after the WebSocket is opened. + * + * @param headerName Name of header to add to the initial handshake request. + * @param headerValue Value of header to add. + */ + void AddHeader(std::string const& headerName, std::string const& headerValue); + + /** @brief Determine if the WebSocket is open. + * + * @returns true if the WebSocket is open, false otherwise. + */ + bool IsOpen() const; + + /** @brief Returns "true" if the configured websocket transport + * supports websockets in the transport, or if the websocket implementation + * is providing websocket protocol support. + * + * @returns true if the HTTP transport used for WebSocket support directly supports the + * WebSocket API. + */ + bool HasBuiltInWebSocketSupport() const; + + /** @brief Returns the protocol chosen by the remote server during the initial handshake. + * + * @returns The protocol negotiated between client and server. + */ + std::string const& GetNegotiatedProtocol() const; + + /** @brief Returns statistics about the WebSocket. + * + * @returns The statistics about the WebSocket. + */ + WebSocketStatistics GetStatistics() const; + + private: + std::unique_ptr + m_socketImplementation; + }; + } // namespace _internal +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/websockets_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets_transport.hpp new file mode 100644 index 0000000000..1afda3c0d2 --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/websockets_transport.hpp @@ -0,0 +1,204 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief Utilities to be used by HTTP WebSocket transport implementations. + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/http.hpp" + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + /** + * @brief Base class for all WebSocket transport implementations. + */ + class WebSocketTransport { + public: + /** + * @brief Web Socket Frame type, one of Text or Binary. + */ + enum class NativeWebSocketFrameType + { + /** + * @brief Indicates that the frame is a partial UTF-8 encoded text frame - it is NOT the + * complete frame to be sent to the remote node. + */ + TextFragment, + /** + * @brief Indicates that the frame is either the complete UTF-8 encoded text frame to be sent + * to the remote node or the final frame of a multipart message. + */ + Text, + /** + * @brief Indicates that the frame is either the complete binary frame to be sent + * to the remote node or the final frame of a multipart message. + */ + Binary, + /** + * @brief Indicates that the frame is a partial binary frame - it is NOT the + * complete frame to be sent to the remote node. + */ + BinaryFragment, + + /** + * @brief Indicates that the frame is a "close" frame - the remote node + * sent a close frame. + */ + Closed, + }; + + /** @brief Close information returned from a WebSocket transport that has builtin support + * for WebSockets. + */ + struct NativeWebSocketCloseInformation + { + /** + * @brief Close response code. + */ + uint16_t CloseReason; + /** + * @brief Close reason. + */ + std::string CloseReasonDescription; + }; + /** @brief Frame information returned from a WebSocket transport that has builtin support + * for WebSockets. + */ + struct NativeWebSocketReceiveInformation + { + /** + * @brief Type of frame received. + */ + NativeWebSocketFrameType FrameType; + /** + * @brief Data received. + */ + std::vector FrameData; + }; + /** + * @brief Destructs `%WebSocketTransport`. + * + */ + virtual ~WebSocketTransport() {} + + /** + * @brief Indicates whether the transport natively supports WebSockets. + * + * @returns true if the transport has native websocket support, false otherwise. + */ + virtual bool HasBuiltInWebSocketSupport() = 0; + + /** + * @brief Closes the WebSocket. + * + * Does not notify the remote endpoint that the socket is being closed. + * + */ + virtual void Close() = 0; + + /**************/ + /* Native WebSocket support functions*/ + /**************/ + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @param status Status value to be sent to the remote node. Application defined. + * @param disconnectReason UTF-8 encoded reason for the disconnection. Optional. + * @param context Context for the operation. + */ + virtual void NativeCloseSocket( + uint16_t status, + std::string const& disconnectReason, + Azure::Core::Context const& context) + = 0; + + /** + * @brief Retrieve the information associated with a WebSocket close response. + * + * @param context Context for the operation. + * + * @returns a tuple containing the status code and string. + */ + virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation( + Azure::Core::Context const& context) + = 0; + + /** + * @brief Send a frame of data to the remote node. + * + * @param frameType Frame type sent to the server, Text or Binary. + * @param frameData Frame data to be sent to the server. + * @param context Context for the operation. + */ + virtual void NativeSendFrame( + NativeWebSocketFrameType frameType, + std::vector const& frameData, + Azure::Core::Context const& context) + = 0; + + /** + * @brief Receive a frame from the remote WebSocket server. + * + * @param context Context for the operation. + * + * @returns a tuple containing the Frame data received from the remote server and the type of + * data returned from the remote endpoint + */ + virtual NativeWebSocketReceiveInformation NativeReceiveFrame( + Azure::Core::Context const& context) + = 0; + + /**************/ + /* Non Native WebSocket support functions */ + /**************/ + + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or until + * there is no more data to get from the socket. + * + */ + virtual size_t ReadFromSocket(uint8_t* buffer, size_t bufferSize, Context const& context) = 0; + + /** + * @brief This method will use the raw socket to write all the bytes from buffer. + * + */ + virtual int SendBuffer(uint8_t const* buffer, size_t bufferSize, Context const& context) = 0; + + protected: + /** + * @brief Constructs a default instance of `%WebSocketTransport`. + * + */ + WebSocketTransport() = default; + + /** + * @brief Constructs `%HttpTransport` by copying another instance of `%HttpTransport`. + * + * @param other An instance to copy. + */ + WebSocketTransport(const WebSocketTransport& other) = default; + + /** + * @brief Constructs a WebSocketTransport from another WebSocketTransport. + * + * @param other An instance to move in. + */ + WebSocketTransport(WebSocketTransport&& other) = default; + + /** + * @brief Assigns one WebSocketTransport to another. + * + * @param other An instance to assign. + * + * @return A reference to this instance. + */ + WebSocketTransport& operator=(const WebSocketTransport& other) = default; + }; + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/websockets/win_http_websockets_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/websockets/win_http_websockets_transport.hpp new file mode 100644 index 0000000000..8fdf5b533c --- /dev/null +++ b/sdk/core/azure-core/inc/azure/core/http/websockets/win_http_websockets_transport.hpp @@ -0,0 +1,140 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +/** + * @file + * @brief #Azure::Core::Http::WebSockets::WebSocketTransport implementation via WInHTTP. + */ + +#pragma once + +#include "azure/core/context.hpp" +#include "azure/core/http/http.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/websockets_transport.hpp" +#include "azure/core/http/win_http_transport.hpp" +#include +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + /** + * @brief Concrete implementation of a WebSocket Transport that uses WinHTTP. + */ + class WinHttpWebSocketTransport : public WebSocketTransport, public WinHttpTransport { + + Azure::Core::Http::_detail::unique_HINTERNET m_socketHandle; + std::mutex m_sendMutex; + std::mutex m_receiveMutex; + + // Called by the + void OnUpgradedConnection( + Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle) override; + + public: + /** + * @brief Construct a new WinHTTP WebSocket Transport. + * + * @param options Optional parameter to override the default options. + */ + WinHttpWebSocketTransport(WinHttpTransportOptions const& options = WinHttpTransportOptions()) + : WinHttpTransport(options) + { + } + + /** + * @brief Implements interface to send an HTTP Request and produce an HTTP RawResponse + * + * @param request an HTTP Request to be send. + * @param context A context to control the request lifetime. + * + * @return unique ptr to an HTTP RawResponse. + */ + virtual std::unique_ptr Send(Request& request, Context const& context) override; + + /** + * @brief Indicates if the transports natively websockets or not. + * + * @details For the WinHTTP websocket transport, the WinHTTP API supports websockets. + */ + virtual bool HasBuiltInWebSocketSupport() override { return true; } + + /** + * @brief Close the underlying WebSocket handle. + * + */ + virtual void Close() override; + + // Native WebSocket support methods. + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + * @param status Status value to be sent to the remote node. Application defined. + * @param disconnectReason UTF-8 encoded reason for the disconnection. Optional. + * @param context Context for the operation. + * + */ + virtual void NativeCloseSocket(uint16_t, std::string const&, Azure::Core::Context const&) + override; + + /** + * @brief Retrieve the information associated with a WebSocket close response. + * + * Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType + * + * @param context Context for the operation. + * + * @returns a tuple containing the status code and string. + */ + virtual NativeWebSocketCloseInformation NativeGetCloseSocketInformation( + Azure::Core::Context const& context) override; + + /** + * @brief Send a frame of data to the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native + * websockets. + * + * @brief frameType Frame type sent to the server, Text or Binary. + * @brief frameData Frame data to be sent to the server. + */ + virtual void NativeSendFrame( + NativeWebSocketFrameType, + std::vector const&, + Azure::Core::Context const&) override; + + virtual NativeWebSocketReceiveInformation NativeReceiveFrame( + Azure::Core::Context const&) override; + + // Non-Native WebSocket support. + /** + * @brief This function is used when working with streams to pull more data from the wire. + * Function will try to keep pulling data from socket until the buffer is all written or + * until there is no more data to get from the socket. + * + * @details Not implemented for WinHTTP websockets because WinHTTP implements websockets + * natively. + */ + virtual size_t ReadFromSocket(uint8_t*, size_t, Context const&) override + { + throw std::runtime_error("Not implemented."); + } + + /** + * @brief This method will use sockets to write all the bytes from buffer. + * + * @details Not implemented for WinHTTP websockets because WinHTTP implements websockets + * natively. + * + */ + virtual int SendBuffer(uint8_t const*, size_t, Context const&) override + { + throw std::runtime_error("Not implemented."); + } + + bool HasWebSocketSupport() const override { return true; } + }; + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp b/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp index 6090d6b49a..459c400db3 100644 --- a/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp +++ b/sdk/core/azure-core/inc/azure/core/http/win_http_transport.hpp @@ -24,6 +24,7 @@ #include #endif +#include #include #include #include @@ -35,41 +36,22 @@ namespace Azure { namespace Core { namespace Http { constexpr static size_t DefaultUploadChunkSize = 1024 * 64; constexpr static size_t MaximumUploadChunkSize = 1024 * 1024; - struct HandleManager final - { - Context const& m_context; - Request& m_request; - HINTERNET m_connectionHandle; - HINTERNET m_requestHandle; - - HandleManager(Request& request, Context const& context) - : m_request(request), m_context(context) - { - m_connectionHandle = NULL; - m_requestHandle = NULL; - } - - ~HandleManager() + // unique_ptr class wrapping an HINTERNET handle + class HINTERNET_deleter { + public: + void operator()(HINTERNET handle) noexcept { - // Close the handles and set them to null to avoid multiple calls to WinHTTP to close the - // handles. - if (m_requestHandle) - { - WinHttpCloseHandle(m_requestHandle); - m_requestHandle = NULL; - } - - if (m_connectionHandle) + if (handle != nullptr) { - WinHttpCloseHandle(m_connectionHandle); - m_connectionHandle = NULL; + WinHttpCloseHandle(handle); } } }; + using unique_HINTERNET = std::unique_ptr; class WinHttpStream final : public Azure::Core::IO::BodyStream { private: - std::unique_ptr m_handleManager; + _detail::unique_HINTERNET m_requestHandle; bool m_isEOF; /** @@ -99,8 +81,8 @@ namespace Azure { namespace Core { namespace Http { size_t OnRead(uint8_t* buffer, size_t count, Azure::Core::Context const& context) override; public: - WinHttpStream(std::unique_ptr handleManager, int64_t contentLength) - : m_handleManager(std::move(handleManager)), m_contentLength(contentLength), + WinHttpStream(_detail::unique_HINTERNET& requestHandle, int64_t contentLength) + : m_requestHandle(std::move(requestHandle)), m_contentLength(contentLength), m_isEOF(false), m_streamTotalRead(0) { } @@ -130,28 +112,53 @@ namespace Azure { namespace Core { namespace Http { * @brief Concrete implementation of an HTTP transport that uses WinHTTP when sending and * receiving requests and responses over the wire. */ - class WinHttpTransport final : public HttpTransport { + class WinHttpTransport : public HttpTransport { private: WinHttpTransportOptions m_options; // This should remain immutable and not be modified after calling the ctor, to avoid threading // issues. - HINTERNET m_sessionHandle = NULL; - - HINTERNET CreateSessionHandle(); - void CreateConnectionHandle(std::unique_ptr<_detail::HandleManager>& handleManager); - void CreateRequestHandle(std::unique_ptr<_detail::HandleManager>& handleManager); - void Upload(std::unique_ptr<_detail::HandleManager>& handleManager); - void SendRequest(std::unique_ptr<_detail::HandleManager>& handleManager); - void ReceiveResponse(std::unique_ptr<_detail::HandleManager>& handleManager); + _detail::unique_HINTERNET m_sessionHandle; + + _detail::unique_HINTERNET CreateSessionHandle(); + _detail::unique_HINTERNET CreateConnectionHandle( + Azure::Core::Url const& url, + Azure::Core::Context const& context); + _detail::unique_HINTERNET CreateRequestHandle( + _detail::unique_HINTERNET const& connectionHandle, + Azure::Core::Url const& url, + Azure::Core::Http::HttpMethod const& method); + void Upload( + _detail::unique_HINTERNET const& requestHandle, + Azure::Core::Http::Request& request, + Azure::Core::Context const& context); + void SendRequest( + _detail::unique_HINTERNET const& requestHandle, + Azure::Core::Http::Request& request, + Azure::Core::Context const& context); + void ReceiveResponse( + _detail::unique_HINTERNET const& requestHandle, + Azure::Core::Context const& context); int64_t GetContentLength( - std::unique_ptr<_detail::HandleManager>& handleManager, + _detail::unique_HINTERNET const& requestHandle, HttpMethod requestMethod, HttpStatusCode responseStatusCode); std::unique_ptr SendRequestAndGetResponse( - std::unique_ptr<_detail::HandleManager> handleManager, + _detail::unique_HINTERNET& requestHandle, HttpMethod requestMethod); + // Callback to allow a derived transport to extract the request handle. Used for WebSocket + // transports. + protected: + virtual void OnUpgradedConnection(_detail::unique_HINTERNET const&){}; + /** + * @brief Throw an exception based on the Win32 Error code + * + * @param exceptionMessage Message describing error. + * @param error Win32 Error code. + */ + void GetErrorAndThrow(const std::string& exceptionMessage, DWORD error = GetLastError()); + public: /** * @brief Constructs `%WinHttpTransport`. @@ -170,16 +177,11 @@ namespace Azure { namespace Core { namespace Http { */ virtual std::unique_ptr Send(Request& request, Context const& context) override; - ~WinHttpTransport() - { - // Close the handles and set them to null to avoid multiple calls to WinHTTP to close the - // handles. - if (m_sessionHandle) - { - WinHttpCloseHandle(m_sessionHandle); - m_sessionHandle = NULL; - } - } + // See also: + // [Core Guidelines C.35: "A base class destructor should be either public + // and virtual or protected and + // non-virtual"](http://isocpp.github.io/CppCoreGuidelines/CppCoreGuidelines#c35-a-base-class-destructor-should-be-either-public-and-virtual-or-protected-and-non-virtual) + virtual ~WinHttpTransport() = default; }; }}} // namespace Azure::Core::Http diff --git a/sdk/core/azure-core/inc/azure/core/internal/cryptography/sha_hash.hpp b/sdk/core/azure-core/inc/azure/core/internal/cryptography/sha_hash.hpp index 446243fe69..87ecf1ddc1 100644 --- a/sdk/core/azure-core/inc/azure/core/internal/cryptography/sha_hash.hpp +++ b/sdk/core/azure-core/inc/azure/core/internal/cryptography/sha_hash.hpp @@ -17,6 +17,62 @@ namespace Azure { namespace Core { namespace Cryptography { namespace _internal { + /** + * @brief Defines #Sha1Hash. + * + * @warning SHA1 is a deprecated hashing algorithm and SHOULD NOT be used, + * unless it is used to implement a specific protocol (for instance, RFC 6455 and + * RFC 7517 both require the use of SHA1 hashes). SHA256, SHA384, and SHA512 are all preferred to + * SHA1. + * + */ + class Sha1Hash final : public Azure::Core::Cryptography::Hash { + public: + /** + * @brief Construct a default instance of #Sha1Hash. + * + */ + Sha1Hash(); + + /** + * @brief Cleanup any state when destroying the instance of #Sha1Hash. + * + */ + ~Sha1Hash() {} + + private: + /** + * @brief Underlying implementation based on the OS. + * + */ + std::unique_ptr m_portableImplementation; + + /** + * @brief Computes the hash value of the specified binary input data, including any previously + * appended. + * @param data The pointer to binary data to compute the hash value for. + * @param length The size of the data provided. + * @return The computed SHA1 hash value corresponding to the input provided including any + * previously appended. + */ + std::vector OnFinal(const uint8_t* data, size_t length) override + { + return m_portableImplementation->Final(data, length); + } + + /** + * @brief Used to append partial binary input data to compute the SHA1 hash in a streaming + * fashion. + * @remark Once all the data has been added, call #Final() to get the computed hash value. + * @param data The pointer to the current block of binary data that is used for hash + * calculation. + * @param length The size of the data provided. + */ + void OnAppend(const uint8_t* data, size_t length) override + { + return m_portableImplementation->Append(data, length); + } + }; /** * @brief Defines #Sha256Hash. * diff --git a/sdk/core/azure-core/src/base64.cpp b/sdk/core/azure-core/src/base64.cpp index 867e796bc5..a2e214c5bc 100644 --- a/sdk/core/azure-core/src/base64.cpp +++ b/sdk/core/azure-core/src/base64.cpp @@ -313,10 +313,10 @@ static void Base64WriteIntAsFourBytes(char* destination, int32_t value) destination[0] = static_cast(value & 0xFF); } -std::string Base64Encode(const std::vector& data) +std::string Base64Encode(uint8_t const* const data, size_t length) { size_t sourceIndex = 0; - auto inputSize = data.size(); + auto inputSize = length; auto maxEncodedSize = ((inputSize + 2) / 3) * 4; // Use a string with size to the max possible result std::string encodedResult(maxEncodedSize, '0'); @@ -490,7 +490,7 @@ namespace Azure { namespace Core { std::string Convert::Base64Encode(const std::vector& data) { - return ::Base64Encode(data); + return ::Base64Encode(data.data(), data.size()); } std::vector Convert::Base64Decode(const std::string& text) diff --git a/sdk/core/azure-core/src/cryptography/md5.cpp b/sdk/core/azure-core/src/cryptography/md5.cpp index a4bbaa217c..79a4a1210f 100644 --- a/sdk/core/azure-core/src/cryptography/md5.cpp +++ b/sdk/core/azure-core/src/cryptography/md5.cpp @@ -194,7 +194,6 @@ class Md5OpenSSL final : public Azure::Core::Cryptography::Hash { } if (1 != EVP_DigestInit_ex(m_context, EVP_md5(), NULL)) { - EVP_MD_CTX_free(m_context); throw std::runtime_error("Crypto error while init Md5Hash."); } } diff --git a/sdk/core/azure-core/src/cryptography/sha_hash.cpp b/sdk/core/azure-core/src/cryptography/sha_hash.cpp index 3ef52b4fb9..95b06b7d7a 100644 --- a/sdk/core/azure-core/src/cryptography/sha_hash.cpp +++ b/sdk/core/azure-core/src/cryptography/sha_hash.cpp @@ -26,6 +26,7 @@ namespace { enum class SHASize { + SHA1, SHA256, SHA384, SHA512 @@ -65,6 +66,13 @@ class SHAWithOpenSSL final : public Azure::Core::Cryptography::Hash { } switch (size) { + case SHASize::SHA1: { + if (1 != EVP_DigestInit_ex(m_context, EVP_sha1(), NULL)) + { + throw std::runtime_error("Crypto error while initializing Sha1Hash."); + } + break; + } case SHASize::SHA256: { if (1 != EVP_DigestInit_ex(m_context, EVP_sha256(), NULL)) { @@ -97,6 +105,11 @@ class SHAWithOpenSSL final : public Azure::Core::Cryptography::Hash { } // namespace +Azure::Core::Cryptography::_internal::Sha1Hash::Sha1Hash() + : m_portableImplementation(std::make_unique(SHASize::SHA1)) +{ +} + Azure::Core::Cryptography::_internal::Sha256Hash::Sha256Hash() : m_portableImplementation(std::make_unique(SHASize::SHA256)) { @@ -222,6 +235,11 @@ class SHAWithBCrypt final : public Azure::Core::Cryptography::Hash { } // namespace +Azure::Core::Cryptography::_internal::Sha1Hash::Sha1Hash() + : m_portableImplementation(std::make_unique(BCRYPT_SHA1_ALGORITHM)) +{ +} + Azure::Core::Cryptography::_internal::Sha256Hash::Sha256Hash() : m_portableImplementation(std::make_unique(BCRYPT_SHA256_ALGORITHM)) { diff --git a/sdk/core/azure-core/src/environment_log_level_listener.cpp b/sdk/core/azure-core/src/environment_log_level_listener.cpp index ffcf1e6076..98ef17dcf3 100644 --- a/sdk/core/azure-core/src/environment_log_level_listener.cpp +++ b/sdk/core/azure-core/src/environment_log_level_listener.cpp @@ -9,6 +9,7 @@ #include #include +#include using Azure::Core::_internal::Environment; using namespace Azure::Core::Diagnostics; @@ -114,7 +115,8 @@ EnvironmentLogLevelListener::GetLogListener() << Azure::DateTime(std::chrono::system_clock::now()) .ToString( DateTime::DateFormat::Rfc3339, DateTime::TimeFractionFormat::AllDigits) - << "] " << LogLevelToConsoleString(level) << " : " << message << std::endl; + << " T: " << std::this_thread::get_id() << "] " << LogLevelToConsoleString(level) + << " : " << message << std::endl; }; return consoleLogger; diff --git a/sdk/core/azure-core/src/http/curl/curl.cpp b/sdk/core/azure-core/src/http/curl/curl.cpp index 6c5123907b..ea65dbd39a 100644 --- a/sdk/core/azure-core/src/http/curl/curl.cpp +++ b/sdk/core/azure-core/src/http/curl/curl.cpp @@ -39,9 +39,13 @@ template // C26812: The enum type 'CURLoption' is un-scoped. Prefer 'enum class' over 'enum' (Enum.3) #pragma warning(disable : 26812) #endif -inline bool SetLibcurlOption(CURL* handle, CURLoption option, T value, CURLcode* outError) +inline bool SetLibcurlOption( + Azure::Core::Http::_detail::unique_CURL const& handle, + CURLoption option, + T value, + CURLcode* outError) { - *outError = curl_easy_setopt(handle, option, value); + *outError = curl_easy_setopt(handle.get(), option, value); return *outError == CURLE_OK; } #if defined(_MSC_VER) @@ -134,15 +138,7 @@ void WinSocketSetBuffSize(curl_socket_t socket) // if WSAloctl succeeded (returned 0), set the socket buffer size. // Specifies the total per-socket buffer space reserved for sends. // https://docs.microsoft.com/windows/win32/api/winsock/nf-winsock-setsockopt - auto result = setsockopt(socket, SOL_SOCKET, SO_SNDBUF, (const char*)&ideal, sizeof(ideal)); - - if (Log::ShouldWrite(Logger::Level::Verbose)) - { - Log::Write( - Logger::Level::Verbose, - LogMsgPrefix + "Windows - calling setsockopt after uploading chunk. ideal = " - + std::to_string(ideal) + " result = " + std::to_string(result)); - } + setsockopt(socket, SOL_SOCKET, SO_SNDBUF, (const char*)&ideal, sizeof(ideal)); } } #endif @@ -321,6 +317,14 @@ std::unique_ptr CurlTransport::Send(Request& request, Context const throw Azure::Core::Http::TransportException( "Error while sending request. " + std::string(curl_easy_strerror(performing))); } + if (HasWebSocketSupport()) + { + std::unique_ptr upgradedConnection(session->ExtractConnection()); + if (upgradedConnection) + { + OnUpgradedConnection(std::move(upgradedConnection)); + } + } Log::Write( Logger::Level::Verbose, @@ -421,6 +425,18 @@ CURLcode CurlSession::Perform(Context const& context) return result; } +std::unique_ptr CurlSession::ExtractConnection() +{ + if (m_connectionUpgraded) + { + return std::move(m_connection); + } + else + { + return nullptr; + } +} + // Creates an HTTP Response with specific bodyType static std::unique_ptr CreateHTTPResponse( uint8_t const* const begin, @@ -484,7 +500,10 @@ CURLcode CurlConnection::SendBuffer( { size_t sentBytesPerRequest = 0; sendResult = curl_easy_send( - m_handle, buffer + sentBytesTotal, bufferSize - sentBytesTotal, &sentBytesPerRequest); + m_handle.get(), + buffer + sentBytesTotal, + bufferSize - sentBytesTotal, + &sentBytesPerRequest); switch (sendResult) { @@ -719,11 +738,19 @@ void CurlSession::ReadStatusLineAndHeadersFromRawResponse( auto connectionHeader = headers.find("connection"); if (connectionHeader != headers.end()) { - if (connectionHeader->second == "close") + if (Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( + connectionHeader->second, "close")) { // Use connection shut-down so it won't be moved it back to the connection pool. m_connection->Shutdown(); } + // If the server indicated that the connection header is "upgrade", it means that this + // is a WebSocket connection so the caller may be upgrading the connection. + if (Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( + connectionHeader->second, "upgrade")) + { + m_connectionUpgraded = true; + } } auto isContentLengthHeaderInResponse = headers.find("content-length"); @@ -880,6 +907,11 @@ size_t CurlSession::OnRead(uint8_t* buffer, size_t count, Context const& context return 0; } + // If we no longer have a connection, read 0 bytes. + if (!m_connection) + { + return 0; + } // Read from socket when no more data on internal buffer // For chunk request, read a chunk based on chunk size totalRead = m_connection->ReadFromSocket(buffer, static_cast(readRequestLength), context); @@ -932,8 +964,7 @@ size_t CurlConnection::ReadFromSocket(uint8_t* buffer, size_t bufferSize, Contex size_t readBytes = 0; for (CURLcode readResult = CURLE_AGAIN; readResult == CURLE_AGAIN;) { - readResult = curl_easy_recv(m_handle, buffer, bufferSize, &readBytes); - + readResult = curl_easy_recv(m_handle.get(), buffer, bufferSize, &readBytes); switch (readResult) { case CURLE_AGAIN: { @@ -1294,15 +1325,12 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo return connection; } } - lock.unlock(); } // Creating a new connection is thread safe. No need to lock mutex here. // No available connection for the pool for the required host. Create one Log::Write(Logger::Level::Verbose, LogMsgPrefix + "Spawn new connection."); - - auto newHandle = std::unique_ptr(curl_easy_init()); - + unique_CURL newHandle(curl_easy_init(), CURL_deleter{}); if (!newHandle) { throw Azure::Core::Http::TransportException( @@ -1312,22 +1340,21 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo CURLcode result; // Libcurl setup before open connection (url, connect_only, timeout) - if (!SetLibcurlOption( - newHandle.get(), CURLOPT_URL, request.GetUrl().GetAbsoluteUrl().data(), &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_URL, request.GetUrl().GetAbsoluteUrl().data(), &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + std::string(curl_easy_strerror(result))); } - if (port != 0 && !SetLibcurlOption(newHandle.get(), CURLOPT_PORT, port, &result)) + if (port != 0 && !SetLibcurlOption(newHandle, CURLOPT_PORT, port, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " + std::string(curl_easy_strerror(result))); } - if (!SetLibcurlOption(newHandle.get(), CURLOPT_CONNECT_ONLY, 1L, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_CONNECT_ONLY, 1L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " @@ -1337,7 +1364,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo // Set timeout to 24h. Libcurl will fail uploading on windows if timeout is: // timeout >= 25 days. Fails as soon as trying to upload any data // 25 days < timeout > 1 days. Fail on huge uploads ( > 1GB) - if (!SetLibcurlOption(newHandle.get(), CURLOPT_TIMEOUT, 60L * 60L * 24L, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_TIMEOUT, 60L * 60L * 24L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName + ". " @@ -1346,8 +1373,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (options.ConnectionTimeout != Azure::Core::Http::_detail::DefaultConnectionTimeout) { - if (!SetLibcurlOption( - newHandle.get(), CURLOPT_CONNECTTIMEOUT_MS, options.ConnectionTimeout, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_CONNECTTIMEOUT_MS, options.ConnectionTimeout, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1362,7 +1388,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo */ if (options.Proxy) { - if (!SetLibcurlOption(newHandle.get(), CURLOPT_PROXY, options.Proxy->c_str(), &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_PROXY, options.Proxy->c_str(), &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1373,7 +1399,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (!options.CAInfo.empty()) { - if (!SetLibcurlOption(newHandle.get(), CURLOPT_CAINFO, options.CAInfo.c_str(), &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_CAINFO, options.CAInfo.c_str(), &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1388,7 +1414,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo sslOption |= CURLSSLOPT_NO_REVOKE; } - if (!SetLibcurlOption(newHandle.get(), CURLOPT_SSL_OPTIONS, sslOption, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_SSL_OPTIONS, sslOption, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1398,7 +1424,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (!options.SslVerifyPeer) { - if (!SetLibcurlOption(newHandle.get(), CURLOPT_SSL_VERIFYPEER, 0L, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_SSL_VERIFYPEER, 0L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1408,7 +1434,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo if (options.NoSignal) { - if (!SetLibcurlOption(newHandle.get(), CURLOPT_NOSIGNAL, 1L, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_NOSIGNAL, 1L, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1420,7 +1446,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo // curl-transport adapter supports only HTTP/1.1 // https://github.com/Azure/azure-sdk-for-cpp/issues/2848 // The libcurl uses HTTP/2 by default, if it can be negotiated with a server on handshake. - if (!SetLibcurlOption(newHandle.get(), CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_HTTP_VERSION, CURL_HTTP_VERSION_1_1, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1428,7 +1454,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo } // Make libcurl to support only TLS v1.2 or later - if (!SetLibcurlOption(newHandle.get(), CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2, &result)) + if (!SetLibcurlOption(newHandle, CURLOPT_SSLVERSION, CURL_SSLVERSION_TLSv1_2, &result)) { throw Azure::Core::Http::TransportException( _detail::DefaultFailedToGetNewConnectionTemplate + hostDisplayName @@ -1443,7 +1469,7 @@ std::unique_ptr CurlConnectionPool::ExtractOrCreateCurlCo + std::string(curl_easy_strerror(performResult))); } - return std::make_unique(newHandle.release(), connectionKey); + return std::make_unique(std::move(newHandle), connectionKey); } // Move the connection back to the connection pool. Push it to the front so it becomes the @@ -1507,5 +1533,4 @@ void CurlConnectionPool::MoveConnectionBackToPool( { Log::Write(Logger::Level::Verbose, "Clean thread running. Won't start a new one."); } - lock.unlock(); } diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp index dff3fad2f1..fdf1334360 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_pool_private.hpp @@ -126,18 +126,4 @@ namespace Azure { namespace Core { namespace Http { namespace _detail { std::thread m_cleanThread; }; - /** - * @brief std::default_delete for the CURL * type , used for std::unique_ptr - * - */ - class CURL_deleter { - public: - void operator()(CURL* handle) noexcept - { - if (handle != nullptr) - { - curl_easy_cleanup(handle); - } - } - }; }}}} // namespace Azure::Core::Http::_detail diff --git a/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp b/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp index 4965848104..705ec2d862 100644 --- a/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_connection_private.hpp @@ -47,6 +47,23 @@ namespace Azure { namespace Core { namespace Http { // Define the maximun allowed connections per host-index in the pool. If this number is reached // for the host-index, next connections trying to be added to the pool will be ignored. constexpr static int32_t MaxConnectionsPerIndex = 1024; + + /** + * @brief std::default_delete for the CURL * type , used for std::unique_ptr + * + */ + class CURL_deleter { + public: + void operator()(CURL* handle) noexcept + { + if (handle != nullptr) + { + curl_easy_cleanup(handle); + } + } + }; + using unique_CURL = std::unique_ptr; + } // namespace _detail /** @@ -122,7 +139,7 @@ namespace Azure { namespace Core { namespace Http { */ class CurlConnection final : public CurlNetworkConnection { private: - CURL* m_handle; + _detail::unique_CURL m_handle; curl_socket_t m_curlSocket; std::chrono::steady_clock::time_point m_lastUseTime; std::string m_connectionKey; @@ -135,8 +152,8 @@ namespace Azure { namespace Core { namespace Http { * * @param connectionPropertiesKey CURL connection properties key */ - CurlConnection(CURL* handle, std::string connectionPropertiesKey) - : m_handle(handle), m_connectionKey(std::move(connectionPropertiesKey)) + CurlConnection(_detail::unique_CURL&& handle, std::string connectionPropertiesKey) + : m_handle(std::move(handle)), m_connectionKey(std::move(connectionPropertiesKey)) { // Get the socket that libcurl is using from handle. Will use this to wait while // reading/writing @@ -146,7 +163,7 @@ namespace Azure { namespace Core { namespace Http { // C26812: The enum type 'CURLcode' is un-scoped. Prefer 'enum class' over 'enum' (Enum.3) #pragma warning(disable : 26812) #endif - auto result = curl_easy_getinfo(m_handle, CURLINFO_ACTIVESOCKET, &m_curlSocket); + auto result = curl_easy_getinfo(m_handle.get(), CURLINFO_ACTIVESOCKET, &m_curlSocket); #if defined(_MSC_VER) #pragma warning(pop) #endif @@ -162,7 +179,7 @@ namespace Azure { namespace Core { namespace Http { * @brief Destructor. * @details Cleans up CURL (invokes `curl_easy_cleanup()`). */ - ~CurlConnection() override { curl_easy_cleanup(this->m_handle); } + ~CurlConnection() override {} std::string const& GetConnectionKey() const override { return this->m_connectionKey; } diff --git a/sdk/core/azure-core/src/http/curl/curl_session_private.hpp b/sdk/core/azure-core/src/http/curl/curl_session_private.hpp index 3e39a2c739..be77f85fbe 100644 --- a/sdk/core/azure-core/src/http/curl/curl_session_private.hpp +++ b/sdk/core/azure-core/src/http/curl/curl_session_private.hpp @@ -274,6 +274,12 @@ namespace Azure { namespace Core { namespace Http { size_t m_sessionTotalRead = 0; + /** + * @brief If True, the connection is going to be "upgraded" into a websocket connection, so + * block moving the connection to the pool. + */ + bool m_connectionUpgraded = false; + /** * @brief Internal buffer from a session used to read bytes from a socket. This buffer is only * used while constructing an HTTP RawResponse without adding a body to it. Customers would @@ -388,7 +394,7 @@ namespace Azure { namespace Core { namespace Http { // By not moving the connection back to the pool, it gets destroyed calling the connection // destructor to clean libcurl handle and close the connection. // IsEOF will also handle a connection that fail to complete an upload request. - if (IsEOF() && m_keepAlive) + if (IsEOF() && m_keepAlive && !m_connectionUpgraded) { _detail::CurlConnectionPool::g_curlConnectionPool.MoveConnectionBackToPool( std::move(m_connection), m_lastStatusCode); @@ -418,6 +424,13 @@ namespace Azure { namespace Core { namespace Http { * @return The size of the payload. */ int64_t Length() const override { return m_contentLength; } + + /** + * @brief Return the network connection if the server indicated that the connection is upgraded. + * + * @return The network connection, or null if the connection was not upgraded. + */ + std::unique_ptr ExtractConnection(); }; }}} // namespace Azure::Core::Http diff --git a/sdk/core/azure-core/src/http/curl/curl_websockets.cpp b/sdk/core/azure-core/src/http/curl/curl_websockets.cpp new file mode 100644 index 0000000000..49e182caee --- /dev/null +++ b/sdk/core/azure-core/src/http/curl/curl_websockets.cpp @@ -0,0 +1,82 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/http/http.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/curl_websockets_transport.hpp" +#include "azure/core/internal/diagnostics/log.hpp" +#include "azure/core/platform.hpp" + +// Private include +#include "curl_connection_private.hpp" + +#if defined(AZ_PLATFORM_POSIX) +#include // for poll() +#include // for socket shutdown +#elif defined(AZ_PLATFORM_WINDOWS) +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#include // for WSAPoll(); +#endif + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + void CurlWebSocketTransport::Close() { m_upgradedConnection->Shutdown(); } + + // Send an HTTP request to the remote server. + std::unique_ptr CurlWebSocketTransport::Send( + Request& request, + Context const& context) + { + // CURL doesn't understand the ws and wss protocols, so change the URL to be http based. + std::string requestScheme(request.GetUrl().GetScheme()); + if (requestScheme == "wss" || requestScheme == "ws") + { + if (requestScheme == "wss") + { + request.GetUrl().SetScheme("https"); + } + else + { + request.GetUrl().SetScheme("http"); + } + } + return CurlTransport::Send(request, context); + } + + size_t CurlWebSocketTransport::ReadFromSocket( + uint8_t* buffer, + size_t bufferSize, + Context const& context) + { + return m_upgradedConnection->ReadFromSocket(buffer, bufferSize, context); + } + + /** + * @brief This method will use libcurl socket to write all the bytes from buffer. + * + */ + int CurlWebSocketTransport::SendBuffer( + uint8_t const* buffer, + size_t bufferSize, + Context const& context) + { + return m_upgradedConnection->SendBuffer(buffer, bufferSize, context); + } + + void CurlWebSocketTransport::OnUpgradedConnection( + std::unique_ptr&& upgradedConnection) + { + // Note that m_upgradedConnection is a std::shared_ptr. We define it as a std::shared_ptr + // because a std::shared_ptr can be declared on an incomplete type, while a std::unique_ptr + // cannot. + m_upgradedConnection = std::move(upgradedConnection); + } + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/src/http/websockets/websockets.cpp b/sdk/core/azure-core/src/http/websockets/websockets.cpp new file mode 100644 index 0000000000..65102f31ab --- /dev/null +++ b/sdk/core/azure-core/src/http/websockets/websockets.cpp @@ -0,0 +1,106 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/http/websockets/websockets.hpp" +#include "azure/core/context.hpp" +#include "websockets_impl.hpp" + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _internal { + + WebSocket::WebSocket(Azure::Core::Url const& remoteUrl, WebSocketOptions const& options) + : m_socketImplementation( + std::make_unique( + remoteUrl, + options)) + + { + } + WebSocket::~WebSocket() {} + + void WebSocket::Open(Azure::Core::Context const& context) + { + m_socketImplementation->Open(context); + } + void WebSocket::Close(Azure::Core::Context const& context) + { + m_socketImplementation->Close( + static_cast(WebSocketErrorCode::EndpointDisappearing), {}, context); + } + void WebSocket::Close( + uint16_t closeStatus, + std::string const& closeReason, + Azure::Core::Context const& context) + { + m_socketImplementation->Close(closeStatus, closeReason, context); + } + + void WebSocket::SendFrame( + std::string const& textFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + m_socketImplementation->SendFrame(textFrame, isFinalFrame, context); + } + + void WebSocket::SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + m_socketImplementation->SendFrame(binaryFrame, isFinalFrame, context); + } + + WebSocketStatistics WebSocket::GetStatistics() const + { + return m_socketImplementation->GetStatistics(); + } + + bool WebSocket::HasBuiltInWebSocketSupport() const + { + return m_socketImplementation->HasBuiltInWebSocketSupport(); + } + + std::shared_ptr WebSocket::ReceiveFrame(Azure::Core::Context const& context) + { + return m_socketImplementation->ReceiveFrame(context); + } + + void WebSocket::AddHeader(std::string const& headerName, std::string const& headerValue) + { + m_socketImplementation->AddHeader(headerName, headerValue); + } + std::string const& WebSocket::GetNegotiatedProtocol() const + { + return m_socketImplementation->GetNegotiatedProtocol(); + } + + bool WebSocket::IsOpen() const { return m_socketImplementation->IsOpen(); } + + std::shared_ptr WebSocketFrame::AsTextFrame() + { + if (FrameType != WebSocketFrameType::TextFrameReceived) + { + throw std::logic_error("Cannot cast to TextFrameReceived."); + } + return static_cast(this)->shared_from_this(); + } + + std::shared_ptr WebSocketFrame::AsBinaryFrame() + { + if (FrameType != WebSocketFrameType::BinaryFrameReceived) + { + throw std::logic_error("Cannot cast to BinaryFrameReceived."); + } + return static_cast(this)->shared_from_this(); + } + + std::shared_ptr WebSocketFrame::AsPeerCloseFrame() + { + if (FrameType != WebSocketFrameType::PeerClosedReceived) + { + throw std::logic_error("Cannot cast to PeerClose."); + } + return static_cast(this)->shared_from_this(); + } + +}}}}} // namespace Azure::Core::Http::WebSockets::_internal diff --git a/sdk/core/azure-core/src/http/websockets/websockets_impl.cpp b/sdk/core/azure-core/src/http/websockets/websockets_impl.cpp new file mode 100644 index 0000000000..90a9a69e33 --- /dev/null +++ b/sdk/core/azure-core/src/http/websockets/websockets_impl.cpp @@ -0,0 +1,879 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT +#include "websockets_impl.hpp" +#include "azure/core/base64.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/internal/cryptography/sha_hash.hpp" +// SUPPORT_NATIVE_TRANSPORT indicates if WinHTTP should be compiled with native transport support +// or not. +// Note that this is primarily required to improve the code coverage numbers in the CI pipeline. +#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER) +#include "azure/core/http/websockets/win_http_websockets_transport.hpp" +#define SUPPORT_NATIVE_TRANSPORT 1 +#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) +#include "azure/core/http/websockets/curl_websockets_transport.hpp" +#define SUPPORT_NATIVE_TRANSPORT 0 +#endif +#include "azure/core/internal/diagnostics/log.hpp" +#include +#include +#include +#include +#include +#include +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail { + using namespace Azure::Core::Http::WebSockets::_internal; + using namespace Azure::Core::Diagnostics::_internal; + using namespace Azure::Core::Diagnostics; + using namespace std::chrono_literals; + + namespace { + std::string HexEncode(std::vector const& data, size_t length) + { + std::stringstream ss; + for (size_t i = 0; i < std::min(data.size(), length); i++) + { + ss << std::hex << std::setfill('0') << std::setw(2) << static_cast(data[i]); + } + return ss.str(); + } + } // namespace + + WebSocketImplementation::WebSocketImplementation( + Azure::Core::Url const& remoteUrl, + WebSocketOptions const& options) + : m_remoteUrl(remoteUrl), m_options(options), m_pingThread(this, m_options.PingInterval) + { + } + + void WebSocketImplementation::Open(Azure::Core::Context const& context) + { + if (m_state != SocketState::Invalid && m_state != SocketState::Closed) + { + throw std::runtime_error( + "Socket in unexpected state: " + std::to_string(static_cast(m_state))); + } + m_state = SocketState::Opening; + +#if defined(BUILD_TRANSPORT_WINHTTP_ADAPTER) + WinHttpTransportOptions transportOptions; + auto winHttpTransport + = std::make_shared( + transportOptions); + m_transport = std::static_pointer_cast(winHttpTransport); + m_options.Transport.Transport = std::static_pointer_cast(winHttpTransport); +#elif defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) + CurlWebSocketTransportOptions transportOptions; + transportOptions.HttpKeepAlive = false; + auto curlWebSockets + = std::make_shared(transportOptions); + + m_transport = std::static_pointer_cast(curlWebSockets); + m_options.Transport.Transport = std::static_pointer_cast(curlWebSockets); +#endif + + std::vector> perCallPolicies{}; + std::vector> perRetryPolicies{}; + // If the caller has told us a service name, add the telemetry policy to the pipeline to add + // a user agent header to the request. + if (!m_options.ServiceName.empty()) + { + perCallPolicies.push_back( + std::make_unique( + m_options.ServiceName, m_options.ServiceVersion, m_options.Telemetry)); + } + Azure::Core::Http::_internal::HttpPipeline openPipeline( + m_options, std::move(perRetryPolicies), std::move(perCallPolicies)); + + Azure::Core::Http::Request openSocketRequest( + Azure::Core::Http::HttpMethod::Get, m_remoteUrl, false); + + // Generate the random request key. Only used when the transport doesn't support websockets + // natively. + auto randomKey = GenerateRandomKey(); + auto encodedKey = Azure::Core::Convert::Base64Encode(randomKey); + if (!m_transport->HasBuiltInWebSocketSupport()) + { + // If the transport doesn't support WebSockets natively, set the standardized WebSocket + // upgrade headers. + openSocketRequest.SetHeader("Upgrade", "websocket"); + openSocketRequest.SetHeader("Connection", "upgrade"); + openSocketRequest.SetHeader("Sec-WebSocket-Version", "13"); + openSocketRequest.SetHeader("Sec-WebSocket-Key", encodedKey); + } + if (!m_options.Protocols.empty()) + { + + std::string protocols; + for (auto const& protocol : m_options.Protocols) + { + protocols += protocol; + protocols += ", "; + } + protocols = protocols.substr(0, protocols.size() - 2); + openSocketRequest.SetHeader("Sec-WebSocket-Protocol", protocols); + } + for (auto const& additionalHeader : m_headers) + { + openSocketRequest.SetHeader(additionalHeader.first, additionalHeader.second); + } + std::string remoteOrigin; + remoteOrigin = m_remoteUrl.GetScheme(); + remoteOrigin += "://"; + remoteOrigin += m_remoteUrl.GetHost(); + openSocketRequest.SetHeader("Origin", remoteOrigin); + + // Send the connect request to the WebSocket server. + auto response = openPipeline.Send(openSocketRequest, context); + + // Ensure that the server thinks we're switching protocols. If it doesn't, + // fail immediately. + if (response->GetStatusCode() != Azure::Core::Http::HttpStatusCode::SwitchingProtocols) + { + throw Azure::Core::Http::TransportException("Unexpected handshake response"); + } + + // Prove that the server received this socket request. + auto& responseHeaders = response->GetHeaders(); + if (!m_transport->HasBuiltInWebSocketSupport()) + { + auto socketAccept(responseHeaders.find("Sec-WebSocket-Accept")); + if (socketAccept == responseHeaders.end()) + { + throw Azure::Core::Http::TransportException("Missing Sec-WebSocket-Accept header"); + } + // Verify that the WebSocket server received *this* open request. + else + { + VerifySocketAccept(encodedKey, socketAccept->second); + } + m_initialBodyStream = response->ExtractBodyStream(); + m_pingThread.Start(m_transport); + } + + // Remember the protocol that the client chose. + auto chosenProtocol = responseHeaders.find("Sec-WebSocket-Protocol"); + if (chosenProtocol != responseHeaders.end()) + { + m_chosenProtocol = chosenProtocol->second; + } + + m_state = SocketState::Open; + } + bool WebSocketImplementation::HasBuiltInWebSocketSupport() + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + return m_transport->HasBuiltInWebSocketSupport(); + } + + std::string const& WebSocketImplementation::GetNegotiatedProtocol() + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + return m_chosenProtocol; + } + + void WebSocketImplementation::AddHeader(std::string const& header, std::string const& headerValue) + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Closed && m_state != SocketState::Invalid) + { + throw std::runtime_error("AddHeader can only be called on closed sockets."); + } + m_headers.emplace(std::make_pair(header, headerValue)); + } + + void WebSocketImplementation::Close( + uint16_t closeStatus, + std::string const& closeReason, + Azure::Core::Context const& context) + { + std::unique_lock lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + + // If we're closing an already closed socket, we're done. + if (m_state == SocketState::Closed) + { + return; + } + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + m_state = SocketState::Closing; +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + m_transport->NativeCloseSocket(closeStatus, closeReason.c_str(), context); + } + else +#endif + { + // Send a going away message to the server. + std::vector closePayload; + closePayload.push_back(closeStatus >> 8); + closePayload.push_back(closeStatus & 0xff); + closePayload.insert(closePayload.end(), closeReason.begin(), closeReason.end()); + std::vector closeFrame = EncodeFrame(SocketOpcode::Close, true, closePayload); + SendTransportBuffer(closeFrame, context); + + // Unlock the state mutex before waiting for the close response to be received. + lock.unlock(); + // Drain the incoming series of frames from the server. + // Note that there might be in-flight frames that were sent from the other end of the + // WebSocket that we don't care about any more (since we're closing the WebSocket). So + // drain those frames. + auto closeResponse = ReceiveTransportFrame(context); + while (closeResponse && closeResponse->Opcode != SocketOpcode::Close) + { + m_receiveStatistics.FramesDroppedByClose++; + Log::Write( + Logger::Level::Warning, + "Received unexpected frame during close. Opcode: " + + std::to_string(static_cast(closeResponse->Opcode))); + closeResponse = ReceiveTransportFrame(context); + } + + // Re-acquire the state lock once we've received the close response. + lock.lock(); + m_stateOwner = std::this_thread::get_id(); + } + // Close the socket - after this point, the m_transport is invalid. + m_pingThread.Shutdown(); + m_transport->Close(); + m_state = SocketState::Closed; + } + + void WebSocketImplementation::SendFrame( + std::string const& textFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + std::vector utf8text(textFrame.begin(), textFrame.end()); + m_receiveStatistics.TextFramesSent++; +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + m_transport->NativeSendFrame( + (isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Text + : WebSocketTransport::NativeWebSocketFrameType::TextFragment), + utf8text, + context); + } + else +#endif + { + std::vector sendFrame = EncodeFrame(SocketOpcode::TextFrame, isFinalFrame, utf8text); + SendTransportBuffer(sendFrame, context); + } + } + + void WebSocketImplementation::SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame, + Azure::Core::Context const& context) + { + std::lock_guard lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + + if (m_state != SocketState::Open) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + m_receiveStatistics.BinaryFramesSent++; +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + m_transport->NativeSendFrame( + (isFinalFrame ? WebSocketTransport::NativeWebSocketFrameType::Binary + : WebSocketTransport::NativeWebSocketFrameType::BinaryFragment), + binaryFrame, + context); + } + else +#endif + { + std::vector sendFrame + = EncodeFrame(SocketOpcode::BinaryFrame, isFinalFrame, binaryFrame); + + SendTransportBuffer(sendFrame, context); + } + } + + std::shared_ptr WebSocketImplementation::ReceiveFrame( + Azure::Core::Context const& context) + { + std::unique_lock lock(m_stateMutex); + m_stateOwner = std::this_thread::get_id(); + + if (m_state != SocketState::Open && m_state != SocketState::Closing) + { + throw std::runtime_error( + "Socket is not open." + std::to_string(static_cast(m_state))); + } + + // Unlock the state lock to allow other threads to run. If we don't, we might end up in in a + // situation where the server won't respond to the this client because all the client threads + // are blocked on the state lock. + lock.unlock(); + + std::shared_ptr frame; + // Loop until we receive an returnable incoming frame. + // If the incoming frame is returnable, we return the value from the frame. + while (true) + { + frame = ReceiveTransportFrame(context); + if (frame) + { + switch (frame->Opcode) + { + // When we receive a "ping" frame, we want to send a Pong frame back to the server. + case SocketOpcode::Ping: + Log::Write( + Logger::Level::Verbose, "Received Ping frame: " + HexEncode(frame->Payload, 16)); + SendPong(frame->Payload, context); + break; + + // We want to ignore all incoming "Pong" frames. + case SocketOpcode::Pong: + Log::Write( + Logger::Level::Verbose, "Received Pong frame: " + HexEncode(frame->Payload, 16)); + break; + + case SocketOpcode::BinaryFrame: + m_currentMessageType = SocketMessageType::Binary; + return std::shared_ptr(new WebSocketBinaryFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + + case SocketOpcode::TextFrame: + m_currentMessageType = SocketMessageType::Text; + return std::shared_ptr(new WebSocketTextFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + + case SocketOpcode::Close: { + if (frame->Payload.size() < 2) + { + throw std::runtime_error("Close response buffer is too short."); + } + // Encode the payload for close according to RFC 6455 + // section 5.5.1. The first two bytes of the payload contain the status code. + // The remainder of the payload is a UTF-8 encoded string. + uint16_t errorCode = 0; + errorCode |= (frame->Payload[0] << 8) & 0xff00; + errorCode |= (frame->Payload[1] & 0x00ff); + + // We received a close frame, mark the socket as closed. Make sure we + // reacquire the state lock before setting the state to closed. + lock.lock(); + m_stateOwner = std::this_thread::get_id(); + m_state = SocketState::Closed; + + return std::shared_ptr(new WebSocketPeerCloseFrame( + errorCode, std::string(frame->Payload.begin() + 2, frame->Payload.end()))); + } + + // Continuation frames need to be treated somewhat specially. + // We depend on the fact that the protocol requires that a Continuation frame + // only be sent if it is part of a multi-frame message whose previous frame was a Text + // or Binary frame. + case SocketOpcode::Continuation: + if (m_currentMessageType == SocketMessageType::Text) + { + if (frame->IsFinalFrame) + { + m_currentMessageType = SocketMessageType::Unknown; + } + return std::shared_ptr(new WebSocketTextFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + } + else if (m_currentMessageType == SocketMessageType::Binary) + { + if (frame->IsFinalFrame) + { + m_currentMessageType = SocketMessageType::Unknown; + } + return std::shared_ptr(new WebSocketBinaryFrame( + frame->IsFinalFrame, frame->Payload.data(), frame->Payload.size())); + } + else + { + m_receiveStatistics.FramesDroppedByProtocolError++; + throw std::runtime_error("Unknown message type and received continuation opcode"); + } + default: + throw std::runtime_error("Unknown frame type received."); + } + } + else + { + if (m_state != SocketState::Closed && m_state != SocketState::Closing) + { + throw std::runtime_error("Transport is at EOF, no frame to receive."); + } + // The socket was closed, most likely locally, so fake a close frame response. + return std::shared_ptr(new WebSocketPeerCloseFrame()); + } + + context.ThrowIfCancelled(); + } + } + + std::shared_ptr + WebSocketImplementation::ReceiveTransportFrame(Azure::Core::Context const& context) + { +#if SUPPORT_NATIVE_TRANSPORT + if (m_transport->HasBuiltInWebSocketSupport()) + { + auto payload = m_transport->NativeReceiveFrame(context); + m_receiveStatistics.FramesReceived++; + switch (payload.FrameType) + { + case WebSocketTransport::NativeWebSocketFrameType::Binary: + m_receiveStatistics.BinaryFramesReceived++; + return std::make_shared( + SocketOpcode::BinaryFrame, true, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::BinaryFragment: + m_receiveStatistics.BinaryFramesReceived++; + return std::make_shared( + SocketOpcode::BinaryFrame, false, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::Text: + m_receiveStatistics.TextFramesReceived++; + return std::make_shared( + SocketOpcode::TextFrame, true, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::TextFragment: + m_receiveStatistics.TextFramesReceived++; + return std::make_shared( + SocketOpcode::TextFrame, false, payload.FrameData); + case WebSocketTransport::NativeWebSocketFrameType::Closed: { + m_receiveStatistics.CloseFramesReceived++; + auto closeResult = m_transport->NativeGetCloseSocketInformation(context); + std::vector closePayload; + closePayload.push_back(closeResult.CloseReason >> 8); + closePayload.push_back(closeResult.CloseReason & 0xff); + closePayload.insert( + closePayload.end(), + closeResult.CloseReasonDescription.begin(), + closeResult.CloseReasonDescription.end()); + return std::make_shared(SocketOpcode::Close, true, closePayload); + } + default: + throw std::runtime_error("Unexpected frame type received."); + } + } + else +#endif + { + std::shared_ptr frame = DecodeFrame(context); + if (frame) + { + + // Handle statistics for the incoming frame. + m_receiveStatistics.FramesReceived++; + switch (frame->Opcode) + { + case SocketOpcode::Ping: { + m_receiveStatistics.PingFramesReceived++; + break; + } + case SocketOpcode::Pong: { + m_receiveStatistics.PongFramesReceived++; + break; + } + case SocketOpcode::TextFrame: { + m_receiveStatistics.TextFramesReceived++; + break; + } + case SocketOpcode::BinaryFrame: { + m_receiveStatistics.BinaryFramesReceived++; + break; + } + case SocketOpcode::Close: { + m_receiveStatistics.CloseFramesReceived++; + break; + } + case SocketOpcode::Continuation: { + m_receiveStatistics.ContinuationFramesReceived++; + break; + } + default: { + m_receiveStatistics.UnknownFramesReceived++; + break; + } + } + } + else + { + m_receiveStatistics.FramesDropped++; + } + return frame; + } + } + + WebSocketStatistics WebSocketImplementation::GetStatistics() const + { + WebSocketStatistics returnValue{}; + returnValue.FramesSent = m_receiveStatistics.FramesSent.load(); + returnValue.FramesReceived = m_receiveStatistics.FramesReceived.load(); + returnValue.BinaryFramesReceived = m_receiveStatistics.BinaryFramesReceived.load(); + returnValue.TextFramesReceived = m_receiveStatistics.TextFramesReceived.load(); + returnValue.BinaryFramesSent = m_receiveStatistics.BinaryFramesSent.load(); + returnValue.TextFramesSent = m_receiveStatistics.TextFramesSent.load(); + returnValue.PingFramesReceived = m_receiveStatistics.PingFramesReceived.load(); + returnValue.PongFramesReceived = m_receiveStatistics.PongFramesReceived.load(); + returnValue.PingFramesSent = m_receiveStatistics.PingFramesSent.load(); + returnValue.PongFramesSent = m_receiveStatistics.PongFramesSent.load(); + + returnValue.BytesSent = m_receiveStatistics.BytesSent.load(); + returnValue.BytesReceived = m_receiveStatistics.BytesReceived.load(); + returnValue.FramesDropped = m_receiveStatistics.FramesDropped.load(); + returnValue.FramesDroppedByClose = m_receiveStatistics.FramesDroppedByClose.load(); + returnValue.FramesDroppedByPayloadSizeLimit + = m_receiveStatistics.FramesDroppedByPayloadSizeLimit.load(); + returnValue.FramesDroppedByProtocolError + = m_receiveStatistics.FramesDroppedByProtocolError.load(); + returnValue.TransportReadBytes = m_receiveStatistics.TransportReadBytes.load(); + returnValue.TransportReads = m_receiveStatistics.TransportReads.load(); + return returnValue; + } + + std::vector WebSocketImplementation::EncodeFrame( + SocketOpcode opcode, + bool isFinal, + std::vector const& payload) + { + std::vector encodedFrame; + // Add opcode+fin. + encodedFrame.push_back(static_cast(opcode) | (isFinal ? 0x80 : 0)); + uint8_t maskAndLength = 0; + maskAndLength |= 0x80; + + // Payloads smaller than 125 bytes are encoded directly in the maskAndLength field. + uint64_t payloadSize = static_cast(payload.size()); + if (payloadSize <= 125) + { + maskAndLength |= static_cast(payload.size()); + } + else if (payloadSize <= 65535) + { + // Payloads greater than 125 whose size can fit in a 16 bit integer bytes + // are encoded as a 16 bit unsigned integer in network byte order. + maskAndLength |= 126; + } + else + { + // Payloads greater than 65536 have their length are encoded as a 64 bit unsigned integer + // in network byte order. + maskAndLength |= 127; + } + encodedFrame.push_back(maskAndLength); + // Encode a 16 bit length. + if (payloadSize > 125 && payloadSize <= 65535) + { + encodedFrame.push_back(static_cast(payload.size()) >> 8); + encodedFrame.push_back(static_cast(payload.size()) & 0xff); + } + // Encode a 64 bit length. + else if (payloadSize >= 65536) + { + + encodedFrame.push_back((payloadSize >> 56) & 0xff); + encodedFrame.push_back((payloadSize >> 48) & 0xff); + encodedFrame.push_back((payloadSize >> 40) & 0xff); + encodedFrame.push_back((payloadSize >> 32) & 0xff); + encodedFrame.push_back((payloadSize >> 24) & 0xff); + encodedFrame.push_back((payloadSize >> 16) & 0xff); + encodedFrame.push_back((payloadSize >> 8) & 0xff); + encodedFrame.push_back(payloadSize & 0xff); + } + // Calculate the masking key. This MUST be 4 bytes of high entropy random numbers used to + // mask the input data. + { + // Start by generating the mask - 4 bytes of random data. + std::vector mask = GenerateRandomBytes(4); + + // Append the mask to the payload. + encodedFrame.insert(encodedFrame.end(), mask.begin(), mask.end()); + + // And mask the payload before transmitting it. + size_t index = 0; + for (auto ch : payload) + { + encodedFrame.push_back(ch ^ mask[index % 4]); + index += 1; + } + } + + return encodedFrame; + } + std::shared_ptr + WebSocketImplementation::DecodeFrame(Azure::Core::Context const& context) + { + // Ensure single threaded access to receive this frame. + std::unique_lock lock(m_transportMutex); + if (IsTransportEof()) + { + throw std::runtime_error("Frame buffer is too small."); + } + uint8_t payloadByte = ReadTransportByte(context); + // If the transport is at EOF, then there is no payload data, so just return null. + if (IsTransportEof()) + { + return nullptr; + } + SocketOpcode opcode = static_cast(payloadByte & 0x7f); + bool isFinal = (payloadByte & 0x80) != 0; + payloadByte = ReadTransportByte(context); + if (IsTransportEof()) + { + return nullptr; + } + if (payloadByte & 0x80) + { + throw std::runtime_error("Server sent a frame with a reserved bit set."); + } + int64_t payloadLength = payloadByte & 0x7f; + if (payloadLength <= 125) + { + payloadByte += 1; + } + else if (payloadLength == 126) + { + payloadLength = ReadTransportShort(context); + } + else if (payloadLength == 127) + { + payloadLength = ReadTransportInt64(context); + } + else + { + throw std::logic_error("Unexpected payload length."); + } + if (IsTransportEof()) + { + return nullptr; + } + + std::vector payload(ReadTransportBytes(static_cast(payloadLength), context)); + if (IsTransportEof()) + { + return nullptr; + } + return std::make_shared(opcode, isFinal, payload); + } + + uint8_t WebSocketImplementation::ReadTransportByte(Azure::Core::Context const& context) + { + if (m_bufferPos >= m_bufferLen) + { + // Start by reading data from our initial body stream. + m_bufferLen = m_initialBodyStream->ReadToCount(m_buffer, m_bufferSize, context); + if (m_bufferLen == 0) + { + // If we run out of the initial stream, we need to read from the transport. + m_bufferLen = m_transport->ReadFromSocket(m_buffer, m_bufferSize, context); + m_receiveStatistics.TransportReads++; + m_receiveStatistics.TransportReadBytes += static_cast(m_bufferLen); + } + else + { + Azure::Core::Diagnostics::_internal::Log::Write( + Azure::Core::Diagnostics::Logger::Level::Informational, + "Read data from initial stream"); + } + m_bufferPos = 0; + if (m_bufferLen == 0) + { + m_eof = true; + return 0; + } + } + + m_receiveStatistics.BytesReceived++; + return m_buffer[m_bufferPos++]; + } + uint16_t WebSocketImplementation::ReadTransportShort(Azure::Core::Context const& context) + { + uint16_t result = ReadTransportByte(context); + result <<= 8; + result |= ReadTransportByte(context); + return result; + } + uint64_t WebSocketImplementation::ReadTransportInt64(Azure::Core::Context const& context) + { + uint64_t result = 0; + + result |= (static_cast(ReadTransportByte(context)) << 56 & 0xff00000000000000); + result |= (static_cast(ReadTransportByte(context)) << 48 & 0x00ff000000000000); + result |= (static_cast(ReadTransportByte(context)) << 40 & 0x0000ff0000000000); + result |= (static_cast(ReadTransportByte(context)) << 32 & 0x000000ff00000000); + result |= (static_cast(ReadTransportByte(context)) << 24 & 0x00000000ff000000); + result |= (static_cast(ReadTransportByte(context)) << 16 & 0x0000000000ff0000); + result |= (static_cast(ReadTransportByte(context)) << 8 & 0x000000000000ff00); + result |= static_cast(ReadTransportByte(context)); + return result; + } + std::vector WebSocketImplementation::ReadTransportBytes( + size_t readLength, + Azure::Core::Context const& context) + { + std::vector result; + size_t index = 0; + while (index < readLength) + { + uint8_t byte = ReadTransportByte(context); + result.push_back(byte); + index += 1; + } + return result; + } + + void WebSocketImplementation::SendTransportBuffer( + std::vector const& sendFrame, + Azure::Core::Context const& context) + { + std::unique_lock transportLock(m_transportMutex); + m_receiveStatistics.BytesSent += static_cast(sendFrame.size()); + m_receiveStatistics.FramesSent += 1; + m_transport->SendBuffer(sendFrame.data(), sendFrame.size(), context); + } + + // Verify the Sec-WebSocket-Accept header as defined in RFC 6455 Section 1.3, which defines + // the opening handshake used for establishing the WebSocket connection. + std::string acceptHeaderGuid = "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"; + void WebSocketImplementation::VerifySocketAccept( + std::string const& encodedKey, + std::string const& acceptHeader) + { + std::string concatenatedKey(encodedKey); + concatenatedKey += acceptHeaderGuid; + Azure::Core::Cryptography::_internal::Sha1Hash sha1hash; + + sha1hash.Append( + reinterpret_cast(concatenatedKey.data()), concatenatedKey.size()); + auto keyHash = sha1hash.Final(); + std::string encodedHash = Azure::Core::Convert::Base64Encode(keyHash); + if (encodedHash != acceptHeader) + { + throw std::runtime_error( + "Hash returned by WebSocket server does not match expected hash. Aborting"); + } + } + + WebSocketImplementation::PingThread::PingThread( + WebSocketImplementation* socketImplementation, + std::chrono::duration pingInterval) + : m_webSocketImplementation(socketImplementation), m_pingInterval(pingInterval) + { + } + void WebSocketImplementation::PingThread::Start(std::shared_ptr transport) + { + m_stop = false; + // Spin up a thread to receive data from the transport. + if (!transport->HasBuiltInWebSocketSupport()) + { + std::unique_lock lock(m_pingThreadStarted); + m_pingThread = std::thread{&PingThread::PingThreadLoop, this}; + m_pingThreadReady.wait(lock); + } + } + + WebSocketImplementation::PingThread::~PingThread() + { + // Ensure that the receive thread is stopped. + Shutdown(); + } + void WebSocketImplementation::PingThread::Shutdown() + { + if (m_pingThread.joinable()) + { + std::unique_lock lock(m_stopMutex); + m_stop = true; + lock.unlock(); + m_pingThreadStopped.notify_all(); + + m_pingThread.join(); + } + } + + void WebSocketImplementation::PingThread::PingThreadLoop() + { + Log::Write(Logger::Level::Verbose, "Start Ping Thread Loop."); + { + std::unique_lock lock(m_pingThreadStarted); + m_pingThreadReady.notify_all(); + } + while (true) + { + std::unique_lock lock(m_stopMutex); + if (this->m_pingThreadStopped.wait_for(lock, m_pingInterval) == std::cv_status::timeout) + { + Log::Write(Logger::Level::Verbose, "Send Ping to peer."); + + // The receiveContext timed out, this means we timed out our "ping" timeout. + // Send a "Ping" request to the remote node. + auto pingData = GenerateRandomBytes(4); + SendPing(pingData, Azure::Core::Context{}); + } + if (m_stop) + { + Log::Write(Logger::Level::Verbose, "Exiting ping thread"); + return; + } + } + } + + bool WebSocketImplementation::PingThread::SendPing( + std::vector const& pingData, + Azure::Core::Context const& context) + { + std::vector pingFrame = EncodeFrame(SocketOpcode::Ping, true, pingData); + m_webSocketImplementation->m_receiveStatistics.PingFramesSent++; + m_webSocketImplementation->SendTransportBuffer(pingFrame, context); + return true; + } + + void WebSocketImplementation::SendPong( + std::vector const& pongData, + Azure::Core::Context const& context) + { + std::vector pongFrame = EncodeFrame(SocketOpcode::Pong, true, pongData); + + m_receiveStatistics.PongFramesSent++; + SendTransportBuffer(pongFrame, context); + } + + // Generator for random bytes. Used in WebSocketImplementation and tests. + std::vector GenerateRandomBytes(size_t vectorSize) + { + std::random_device randomEngine; + + std::vector rv(vectorSize); + std::generate(begin(rv), end(rv), [&randomEngine]() mutable { + return static_cast(randomEngine() % UINT8_MAX); + }); + return rv; + } +}}}}} // namespace Azure::Core::Http::WebSockets::_detail diff --git a/sdk/core/azure-core/src/http/websockets/websockets_impl.hpp b/sdk/core/azure-core/src/http/websockets/websockets_impl.hpp new file mode 100644 index 0000000000..73a10ec143 --- /dev/null +++ b/sdk/core/azure-core/src/http/websockets/websockets_impl.hpp @@ -0,0 +1,372 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT +#include "azure/core/http/websockets/websockets.hpp" +#include "azure/core/http/websockets/websockets_transport.hpp" +#include "azure/core/internal/diagnostics/log.hpp" +#include "azure/core/internal/http/pipeline.hpp" +#include +#include +#include +#include +#include + +// Implementation of WebSocket protocol. +namespace Azure { namespace Core { namespace Http { namespace WebSockets { namespace _detail { + + // Generator for random bytes. Used in WebSocketImplementation and tests. + std::vector GenerateRandomBytes(size_t vectorSize); + + class WebSocketImplementation { + enum class SocketState + { + Invalid, + Closed, + Opening, + Open, + Closing, + }; + + public: + WebSocketImplementation( + Azure::Core::Url const& remoteUrl, + _internal::WebSocketOptions const& options); + + void Open(Azure::Core::Context const& context); + void Close( + uint16_t closeStatus, + std::string const& closeReason, + Azure::Core::Context const& context); + void SendFrame( + std::string const& textFrame, + bool isFinalFrame, + Azure::Core::Context const& context); + void SendFrame( + std::vector const& binaryFrame, + bool isFinalFrame, + Azure::Core::Context const& context); + + std::shared_ptr<_internal::WebSocketFrame> ReceiveFrame(Azure::Core::Context const& context); + + void AddHeader(std::string const& headerName, std::string const& headerValue); + + std::string const& GetNegotiatedProtocol(); + bool IsOpen() { return m_state == SocketState::Open; } + bool HasBuiltInWebSocketSupport(); + + _internal::WebSocketStatistics GetStatistics() const; + + private: + // WebSocket opcodes. + enum class SocketOpcode : uint8_t + { + Continuation = 0x00, + TextFrame = 0x01, + BinaryFrame = 0x02, + Close = 0x08, + Ping = 0x09, + Pong = 0x0a + }; + + /** + * Indicates the type of the message currently being processed. Used when processing + * Continuation Opcode frames. + */ + enum class SocketMessageType : int + { + Unknown, + Text, + Binary, + }; + + class WebSocketInternalFrame { + public: + SocketOpcode Opcode{}; + bool IsFinalFrame{false}; + std::vector Payload; + std::exception_ptr Exception; + WebSocketInternalFrame( + SocketOpcode opcode, + bool isFinalFrame, + std::vector const& payload) + : Opcode(opcode), IsFinalFrame(isFinalFrame), Payload(payload) + { + } + WebSocketInternalFrame(std::exception_ptr exception) : Exception(exception) {} + }; + + struct ReceiveStatistics + { + std::atomic FramesSent; + std::atomic FramesReceived; + std::atomic BytesSent; + std::atomic BytesReceived; + std::atomic PingFramesSent; + std::atomic PingFramesReceived; + std::atomic PongFramesSent; + std::atomic PongFramesReceived; + std::atomic TextFramesReceived; + std::atomic BinaryFramesReceived; + std::atomic ContinuationFramesReceived; + std::atomic CloseFramesReceived; + std::atomic UnknownFramesReceived; + std::atomic FramesDropped; + std::atomic FramesDroppedByPayloadSizeLimit; + std::atomic FramesDroppedByProtocolError; + std::atomic TransportReads; + std::atomic TransportReadBytes; + std::atomic BinaryFramesSent; + std::atomic TextFramesSent; + std::atomic FramesDroppedByClose; + + void Reset() + { + FramesSent = 0; + BytesSent = 0; + FramesReceived = 0; + BytesReceived = 0; + PingFramesReceived = 0; + PingFramesSent = 0; + PongFramesReceived = 0; + PongFramesSent = 0; + TextFramesReceived = 0; + TextFramesSent = 0; + BinaryFramesReceived = 0; + BinaryFramesSent = 0; + ContinuationFramesReceived = 0; + CloseFramesReceived = 0; + UnknownFramesReceived = 0; + FramesDropped = 0; + FramesDroppedByClose = 0; + FramesDroppedByPayloadSizeLimit = 0; + FramesDroppedByProtocolError = 0; + TransportReads = 0; + TransportReadBytes = 0; + } + }; + /** + * @brief The PingThread handles sending Ping operations from the WebSocket server. + * + */ + class PingThread { + public: + /** + * @brief Construct a new ReceiveQueue object. + * + * @param webSocketImplementation Parent object, used to send Ping threads. + * @param pingInterval Interval to wait between sending pings. + */ + PingThread( + WebSocketImplementation* webSocketImplementation, + std::chrono::duration pingInterval); + /** + * @brief Destroys a ReceiveQueue object. Blocks until the queue thread is completed. + */ + ~PingThread(); + + /** + * @brief Start the receive queue. This will start a thread that will process incoming frames. + * + * @param transport The websocket transport to use for receiving frames. + */ + void Start(std::shared_ptr transport); + /** + * @brief Stop the receive queue. This will stop the thread that processes incoming frames. + */ + void Shutdown(); + + private: + /** + * @brief The receive queue thread. + */ + void PingThreadLoop(); + /** + * @brief Send a "ping" frame to the other side of the WebSocket. + * + * @returns True if the ping was sent, false if the underlying transport didn't support "Ping" + * operations. + */ + bool SendPing(std::vector const& pingData, Azure::Core::Context const& context); + + WebSocketImplementation* m_webSocketImplementation; + std::chrono::duration m_pingInterval; + std::thread m_pingThread; + std::mutex m_pingThreadStarted; + std::condition_variable m_pingThreadReady; + + std::mutex m_stopMutex; + std::condition_variable m_pingThreadStopped; + bool m_stop = false; + }; + + /** + * @brief Encode a websocket frame according to RFC 6455 section 5.2. + * + * This wire format for the data transfer part is described by the ABNF + * [RFC5234] given in detail in this section. (Note that, unlike in + * other sections of this document, the ABNF in this section is + * operating on groups of bits. The length of each group of bits is + * indicated in a comment. When encoded on the wire, the most + * significant bit is the leftmost in the ABNF). A high-level overview + * of the framing is given in the following figure. In a case of + * conflict between the figure below and the ABNF specified later in + * this section, the figure is authoritative. + * + * 0 1 2 3 + * 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 2 3 4 5 6 7 8 9 0 1 + * +-+-+-+-+-------+-+-------------+-------------------------------+ + * |F|R|R|R| opcode|M| Payload len | Extended payload length | + * |I|S|S|S| (4) |A| (7) | (16/64) | + * |N|V|V|V| |S| | (if payload len==126/127) | + * | |1|2|3| |K| | | + * +-+-+-+-+-------+-+-------------+ - - - - - - - - - - - - - - - + + * | Extended payload length continued, if payload len == 127 | + * + - - - - - - - - - - - - - - - +-------------------------------+ + * | |Masking-key, if MASK set to 1 | + * +-------------------------------+-------------------------------+ + * | Masking-key (continued) | Payload Data | + * +-------------------------------- - - - - - - - - - - - - - - - + + * : Payload Data continued ... : + * + - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + * | Payload Data continued ... | + * +---------------------------------------------------------------+ + * + * FIN: 1 bit + * + * Indicates that this is the final fragment in a message. The first + * fragment MAY also be the final fragment. + * + * RSV1, RSV2, RSV3: 1 bit each + * + * MUST be 0 unless an extension is negotiated that defines meanings + * for non-zero values. If a nonzero value is received and none of + * the negotiated extensions defines the meaning of such a nonzero + * value, the receiving endpoint MUST _Fail the WebSocket + * Connection_. + * + * Opcode: 4 bits + * + * Defines the interpretation of the "Payload data". If an unknown + * opcode is received, the receiving endpoint MUST _Fail the + * WebSocket Connection_. The following values are defined. + * + * * %x0 denotes a continuation frame + * + * * %x1 denotes a text frame + * + * * %x2 denotes a binary frame + * + * * %x3-7 are reserved for further non-control frames + * + * * %x8 denotes a connection close + * + * * %x9 denotes a ping + * + * * %xA denotes a pong + * + * * %xB-F are reserved for further control frames + * + * Mask: 1 bit + * + * Defines whether the "Payload data" is masked. If set to 1, a + * masking key is present in masking-key, and this is used to unmask + * the "Payload data" as per Section 5.3. All frames sent from + * client to server have this bit set to 1. + * + * Payload length: 7 bits, 7+16 bits, or 7+64 bits + * + * The length of the "Payload data", in bytes: if 0-125, that is the + * payload length. If 126, the following 2 bytes interpreted as a + * 16-bit unsigned integer are the payload length. If 127, the + * following 8 bytes interpreted as a 64-bit unsigned integer (the + * most significant bit MUST be 0) are the payload length. Multibyte + * length quantities are expressed in network byte order. Note that + * in all cases, the minimal number of bytes MUST be used to encode + * the length, for example, the length of a 124-byte-long string + * can't be encoded as the sequence 126, 0, 124. The payload length + * is the length of the "Extension data" + the length of the + * "Application data". The length of the "Extension data" may be + * zero, in which case the payload length is the length of the + * "Application data". + * Masking-key: 0 or 4 bytes + * + * All frames sent from the client to the server are masked by a + * 32-bit value that is contained within the frame. This field is + * present if the mask bit is set to 1 and is absent if the mask bit + * is set to 0. See Section 5.3 for further information on client- + * to-server masking. + * + * Payload data: (x+y) bytes + * + * The "Payload data" is defined as "Extension data" concatenated + * with "Application data". + * + * Extension data: x bytes + * + * The "Extension data" is 0 bytes unless an extension has been + * negotiated. Any extension MUST specify the length of the + * "Extension data", or how that length may be calculated, and how + * the extension use MUST be negotiated during the opening handshake. + * If present, the "Extension data" is included in the total payload + * length. + * + * Application data: y bytes + * + * Arbitrary "Application data", taking up the remainder of the frame + * after any "Extension data". The length of the "Application data" + * is equal to the payload length minus the length of the "Extension + * data". + */ + static std::vector EncodeFrame( + SocketOpcode opcode, + bool isFinal, + std::vector const& payload); + + SocketState m_state{SocketState::Invalid}; + + std::vector GenerateRandomKey() { return GenerateRandomBytes(16); }; + void VerifySocketAccept(std::string const& encodedKey, std::string const& acceptHeader); + + /********* + * Buffered Read Support. Read data from the underlying transport into a buffer. + */ + uint8_t ReadTransportByte(Azure::Core::Context const& context); + uint16_t ReadTransportShort(Azure::Core::Context const& context); + uint64_t ReadTransportInt64(Azure::Core::Context const& context); + std::vector ReadTransportBytes(size_t readLength, Azure::Core::Context const& context); + bool IsTransportEof() const { return m_eof; } + void SendPong(std::vector const& pongData, Azure::Core::Context const& context); + void SendTransportBuffer( + std::vector const& payload, + Azure::Core::Context const& context); + std::shared_ptr ReceiveTransportFrame( + Azure::Core::Context const& context); + + /** + * @brief Decode a frame received from the websocket server. + * + * @returns A pointer to the start of the decoded data. + */ + std::shared_ptr DecodeFrame(Azure::Core::Context const& context); + + Azure::Core::Url m_remoteUrl; + _internal::WebSocketOptions m_options; + std::map m_headers; + std::string m_chosenProtocol; + std::shared_ptr m_transport; + PingThread m_pingThread; + SocketMessageType m_currentMessageType{SocketMessageType::Unknown}; + std::mutex m_stateMutex; + std::thread::id m_stateOwner; + + ReceiveStatistics m_receiveStatistics{}; + + std::mutex m_transportMutex; + + std::unique_ptr m_initialBodyStream; + constexpr static size_t m_bufferSize = 1024; + uint8_t m_buffer[m_bufferSize]{}; + size_t m_bufferPos = 0; + size_t m_bufferLen = 0; + bool m_eof = false; + }; +}}}}} // namespace Azure::Core::Http::WebSockets::_detail diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp index 7726305bf9..9131eb98df 100644 --- a/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp +++ b/sdk/core/azure-core/src/http/winhttp/win_http_transport.cpp @@ -20,6 +20,7 @@ using namespace Azure::Core::Http; namespace { const std::string HttpScheme = "http"; +const std::string WebSocketScheme = "ws"; inline std::wstring HttpMethodToWideString(HttpMethod method) { @@ -198,9 +199,8 @@ std::string GetHeadersAsString(Azure::Core::Http::Request const& request) } // namespace -void GetErrorAndThrow(const std::string& exceptionMessage) +void WinHttpTransport::GetErrorAndThrow(const std::string& exceptionMessage, DWORD error) { - DWORD error = GetLastError(); std::string errorMessage = exceptionMessage + " Error Code: " + std::to_string(error); char* errorMsg = nullptr; @@ -226,17 +226,19 @@ void GetErrorAndThrow(const std::string& exceptionMessage) throw Azure::Core::Http::TransportException(errorMessage); } -HINTERNET WinHttpTransport::CreateSessionHandle() +_detail::unique_HINTERNET WinHttpTransport::CreateSessionHandle() { // Use WinHttpOpen to obtain a session handle. // The dwFlags is set to 0 - all WinHTTP functions are performed synchronously. - HINTERNET sessionHandle = WinHttpOpen( - NULL, // Do not use a fallback user-agent string, and only rely on the header within the - // request itself. - WINHTTP_ACCESS_TYPE_NO_PROXY, - WINHTTP_NO_PROXY_NAME, - WINHTTP_NO_PROXY_BYPASS, - 0); + _detail::unique_HINTERNET sessionHandle( + WinHttpOpen( + NULL, // Do not use a fallback user-agent string, and only rely on the header within the + // request itself. + WINHTTP_ACCESS_TYPE_NO_PROXY, + WINHTTP_NO_PROXY_NAME, + WINHTTP_NO_PROXY_BYPASS, + 0), + _detail::HINTERNET_deleter{}); if (!sessionHandle) { @@ -253,19 +255,22 @@ HINTERNET WinHttpTransport::CreateSessionHandle() #ifdef WINHTTP_OPTION_TCP_FAST_OPEN BOOL tcp_fast_open = TRUE; WinHttpSetOption( - sessionHandle, WINHTTP_OPTION_TCP_FAST_OPEN, &tcp_fast_open, sizeof(tcp_fast_open)); + sessionHandle.get(), WINHTTP_OPTION_TCP_FAST_OPEN, &tcp_fast_open, sizeof(tcp_fast_open)); #endif #ifdef WINHTTP_OPTION_TLS_FALSE_START BOOL tls_false_start = TRUE; WinHttpSetOption( - sessionHandle, WINHTTP_OPTION_TLS_FALSE_START, &tls_false_start, sizeof(tls_false_start)); + sessionHandle.get(), + WINHTTP_OPTION_TLS_FALSE_START, + &tls_false_start, + sizeof(tls_false_start)); #endif // Enforce TLS version 1.2 auto tlsOption = WINHTTP_FLAG_SECURE_PROTOCOL_TLS1_2; if (!WinHttpSetOption( - sessionHandle, WINHTTP_OPTION_SECURE_PROTOCOLS, &tlsOption, sizeof(tlsOption))) + sessionHandle.get(), WINHTTP_OPTION_SECURE_PROTOCOLS, &tlsOption, sizeof(tlsOption))) { GetErrorAndThrow("Error while enforcing TLS 1.2 for connection request."); } @@ -278,23 +283,26 @@ WinHttpTransport::WinHttpTransport(WinHttpTransportOptions const& options) { } -void WinHttpTransport::CreateConnectionHandle( - std::unique_ptr<_detail::HandleManager>& handleManager) +_detail::unique_HINTERNET WinHttpTransport::CreateConnectionHandle( + Azure::Core::Url const& url, + Azure::Core::Context const& context) { // If port is 0, i.e. INTERNET_DEFAULT_PORT, it uses port 80 for HTTP and port 443 for HTTPS. - uint16_t port = handleManager->m_request.GetUrl().GetPort(); + uint16_t port = url.GetPort(); - handleManager->m_context.ThrowIfCancelled(); + context.ThrowIfCancelled(); // Specify an HTTP server. // This function always operates synchronously. - handleManager->m_connectionHandle = WinHttpConnect( - m_sessionHandle, - StringToWideString(handleManager->m_request.GetUrl().GetHost()).c_str(), - port == 0 ? INTERNET_DEFAULT_PORT : port, - 0); - - if (!handleManager->m_connectionHandle) + _detail::unique_HINTERNET rv( + WinHttpConnect( + m_sessionHandle.get(), + StringToWideString(url.GetHost()).c_str(), + port == 0 ? INTERNET_DEFAULT_PORT : port, + 0), + _detail::HINTERNET_deleter{}); + + if (!rv) { // Errors include: // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE @@ -306,29 +314,35 @@ void WinHttpTransport::CreateConnectionHandle( // ERROR_NOT_ENOUGH_MEMORY GetErrorAndThrow("Error while getting a connection handle."); } + return rv; } -void WinHttpTransport::CreateRequestHandle(std::unique_ptr<_detail::HandleManager>& handleManager) +_detail::unique_HINTERNET WinHttpTransport::CreateRequestHandle( + _detail::unique_HINTERNET const& connectionHandle, + Azure::Core::Url const& url, + Azure::Core::Http::HttpMethod const& method) { - const std::string& path = handleManager->m_request.GetUrl().GetRelativeUrl(); - HttpMethod requestMethod = handleManager->m_request.GetMethod(); + const std::string& path = url.GetRelativeUrl(); + HttpMethod requestMethod = method; bool const requestSecureHttp( !Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( - handleManager->m_request.GetUrl().GetScheme(), HttpScheme)); + url.GetScheme(), HttpScheme) + && !Azure::Core::_internal::StringExtensions::LocaleInvariantCaseInsensitiveEqual( + url.GetScheme(), WebSocketScheme)); // Create an HTTP request handle. - handleManager->m_requestHandle = WinHttpOpenRequest( - handleManager->m_connectionHandle, - HttpMethodToWideString(requestMethod).c_str(), - path.empty() ? NULL - : StringToWideString(path) - .c_str(), // Name of the target resource of the specified HTTP verb - NULL, // Use HTTP/1.1 - WINHTTP_NO_REFERER, - WINHTTP_DEFAULT_ACCEPT_TYPES, // No media types are accepted by the client - requestSecureHttp ? WINHTTP_FLAG_SECURE : 0); // Uses secure transaction semantics (SSL/TLS) - - if (!handleManager->m_requestHandle) + _detail::unique_HINTERNET request( + WinHttpOpenRequest( + connectionHandle.get(), + HttpMethodToWideString(requestMethod).c_str(), + path.empty() ? NULL : StringToWideString(path).c_str(), // Name of the target resource of + // the specified HTTP verb + NULL, // Use HTTP/1.1 + WINHTTP_NO_REFERER, + WINHTTP_DEFAULT_ACCEPT_TYPES, // No media types are accepted by the client + requestSecureHttp ? WINHTTP_FLAG_SECURE : 0), + _detail::HINTERNET_deleter{}); // Uses secure transaction semantics (SSL/TLS) + if (!request) { // Errors include: // ERROR_WINHTTP_INCORRECT_HANDLE_TYPE @@ -348,10 +362,7 @@ void WinHttpTransport::CreateRequestHandle(std::unique_ptr<_detail::HandleManage // Note: If/When TLS client certificate support is added to the pipeline, this line may need to // be revisited. if (!WinHttpSetOption( - handleManager->m_requestHandle, - WINHTTP_OPTION_CLIENT_CERT_CONTEXT, - WINHTTP_NO_CLIENT_CERT_CONTEXT, - 0)) + request.get(), WINHTTP_OPTION_CLIENT_CERT_CONTEXT, WINHTTP_NO_CLIENT_CERT_CONTEXT, 0)) { GetErrorAndThrow("Error while setting client cert context to ignore."); } @@ -360,18 +371,29 @@ void WinHttpTransport::CreateRequestHandle(std::unique_ptr<_detail::HandleManage if (m_options.IgnoreUnknownCertificateAuthority) { auto option = SECURITY_FLAG_IGNORE_UNKNOWN_CA; - if (!WinHttpSetOption( - handleManager->m_requestHandle, WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) + if (!WinHttpSetOption(request.get(), WINHTTP_OPTION_SECURITY_FLAGS, &option, sizeof(option))) { GetErrorAndThrow("Error while setting ignore unknown server certificate."); } } + + // If we are supporting WebSockets, then let WinHTTP know that it should + // prepare to upgrade the HttpRequest to a WebSocket. + if (HasWebSocketSupport() + && !WinHttpSetOption(request.get(), WINHTTP_OPTION_UPGRADE_TO_WEB_SOCKET, nullptr, 0)) + { + GetErrorAndThrow("Error while Enabling WebSocket upgrade."); + } + return request; } // For PUT/POST requests, send additional data using WinHttpWriteData. -void WinHttpTransport::Upload(std::unique_ptr<_detail::HandleManager>& handleManager) +void WinHttpTransport::Upload( + _detail::unique_HINTERNET const& requestHandle, + Azure::Core::Http::Request& request, + Azure::Core::Context const& context) { - auto streamBody = handleManager->m_request.GetBodyStream(); + auto streamBody = request.GetBodyStream(); int64_t streamLength = streamBody->Length(); // Consider using `MaximumUploadChunkSize` here, after some perf measurements @@ -384,8 +406,7 @@ void WinHttpTransport::Upload(std::unique_ptr<_detail::HandleManager>& handleMan while (true) { - size_t rawRequestLen - = streamBody->Read(unique_buffer.get(), uploadChunkSize, handleManager->m_context); + size_t rawRequestLen = streamBody->Read(unique_buffer.get(), uploadChunkSize, context); if (rawRequestLen == 0) { break; @@ -393,11 +414,11 @@ void WinHttpTransport::Upload(std::unique_ptr<_detail::HandleManager>& handleMan DWORD dwBytesWritten = 0; - handleManager->m_context.ThrowIfCancelled(); + context.ThrowIfCancelled(); // Write data to the server. if (!WinHttpWriteData( - handleManager->m_requestHandle, + requestHandle.get(), unique_buffer.get(), static_cast(rawRequestLen), &dwBytesWritten)) @@ -407,29 +428,32 @@ void WinHttpTransport::Upload(std::unique_ptr<_detail::HandleManager>& handleMan } } -void WinHttpTransport::SendRequest(std::unique_ptr<_detail::HandleManager>& handleManager) +void WinHttpTransport::SendRequest( + _detail::unique_HINTERNET const& requestHandle, + Azure::Core::Http::Request& request, + Azure::Core::Context const& context) { std::wstring encodedHeaders; int encodedHeadersLength = 0; - auto requestHeaders = handleManager->m_request.GetHeaders(); + auto requestHeaders = request.GetHeaders(); if (requestHeaders.size() != 0) { // The encodedHeaders will be null-terminated and the length is calculated. encodedHeadersLength = -1; - std::string requestHeaderString = GetHeadersAsString(handleManager->m_request); + std::string requestHeaderString = GetHeadersAsString(request); requestHeaderString.append("\0"); encodedHeaders = StringToWideString(requestHeaderString); } - int64_t streamLength = handleManager->m_request.GetBodyStream()->Length(); + int64_t streamLength = request.GetBodyStream()->Length(); - handleManager->m_context.ThrowIfCancelled(); + context.ThrowIfCancelled(); // Send a request. if (!WinHttpSendRequest( - handleManager->m_requestHandle, + requestHandle.get(), requestHeaders.size() == 0 ? WINHTTP_NO_ADDITIONAL_HEADERS : encodedHeaders.c_str(), encodedHeadersLength, WINHTTP_NO_REQUEST_DATA, @@ -468,18 +492,20 @@ void WinHttpTransport::SendRequest(std::unique_ptr<_detail::HandleManager>& hand if (streamLength > 0) { - Upload(handleManager); + Upload(requestHandle, request, context); } } -void WinHttpTransport::ReceiveResponse(std::unique_ptr<_detail::HandleManager>& handleManager) +void WinHttpTransport::ReceiveResponse( + _detail::unique_HINTERNET const& requestHandle, + Azure::Core::Context const& context) { - handleManager->m_context.ThrowIfCancelled(); + context.ThrowIfCancelled(); // Wait to receive the response to the HTTP request initiated by WinHttpSendRequest. // When WinHttpReceiveResponse completes successfully, the status code and response headers have // been received. - if (!WinHttpReceiveResponse(handleManager->m_requestHandle, NULL)) + if (!WinHttpReceiveResponse(requestHandle.get(), NULL)) { // Errors include: // ERROR_WINHTTP_CANNOT_CONNECT @@ -494,7 +520,7 @@ void WinHttpTransport::ReceiveResponse(std::unique_ptr<_detail::HandleManager>& } int64_t WinHttpTransport::GetContentLength( - std::unique_ptr<_detail::HandleManager>& handleManager, + _detail::unique_HINTERNET const& requestHandle, HttpMethod requestMethod, HttpStatusCode responseStatusCode) { @@ -511,7 +537,7 @@ int64_t WinHttpTransport::GetContentLength( if (requestMethod != HttpMethod::Head && responseStatusCode != HttpStatusCode::NoContent) { if (!WinHttpQueryHeaders( - handleManager->m_requestHandle, + requestHandle.get(), WINHTTP_QUERY_CONTENT_LENGTH | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &dwContentLength, @@ -530,14 +556,14 @@ int64_t WinHttpTransport::GetContentLength( } std::unique_ptr WinHttpTransport::SendRequestAndGetResponse( - std::unique_ptr<_detail::HandleManager> handleManager, + _detail::unique_HINTERNET& requestHandle, HttpMethod requestMethod) { // First, use WinHttpQueryHeaders to obtain the size of the buffer. // The call is expected to fail since no destination buffer is provided. DWORD sizeOfHeaders = 0; if (WinHttpQueryHeaders( - handleManager->m_requestHandle, + requestHandle.get(), WINHTTP_QUERY_RAW_HEADERS, WINHTTP_HEADER_NAME_BY_INDEX, NULL, @@ -563,7 +589,7 @@ std::unique_ptr WinHttpTransport::SendRequestAndGetResponse( // Now, use WinHttpQueryHeaders to retrieve all the headers. // Each header is terminated by "\0". An additional "\0" terminates the list of headers. if (!WinHttpQueryHeaders( - handleManager->m_requestHandle, + requestHandle.get(), WINHTTP_QUERY_RAW_HEADERS, WINHTTP_HEADER_NAME_BY_INDEX, outputBuffer.data(), @@ -583,7 +609,7 @@ std::unique_ptr WinHttpTransport::SendRequestAndGetResponse( // Get the HTTP version. if (!WinHttpQueryHeaders( - handleManager->m_requestHandle, + requestHandle.get(), WINHTTP_QUERY_VERSION, WINHTTP_HEADER_NAME_BY_INDEX, outputBuffer.data(), @@ -606,7 +632,7 @@ std::unique_ptr WinHttpTransport::SendRequestAndGetResponse( // Get the status code as a number. if (!WinHttpQueryHeaders( - handleManager->m_requestHandle, + requestHandle.get(), WINHTTP_QUERY_STATUS_CODE | WINHTTP_QUERY_FLAG_NUMBER, WINHTTP_HEADER_NAME_BY_INDEX, &statusCode, @@ -623,7 +649,7 @@ std::unique_ptr WinHttpTransport::SendRequestAndGetResponse( DWORD sizeOfReasonPhrase = sizeOfHeaders; if (WinHttpQueryHeaders( - handleManager->m_requestHandle, + requestHandle.get(), WINHTTP_QUERY_STATUS_TEXT, WINHTTP_HEADER_NAME_BY_INDEX, outputBuffer.data(), @@ -642,26 +668,32 @@ std::unique_ptr WinHttpTransport::SendRequestAndGetResponse( SetHeaders(responseHeaders, rawResponse); - int64_t contentLength - = GetContentLength(handleManager, requestMethod, rawResponse->GetStatusCode()); + if (HasWebSocketSupport() && (httpStatusCode == HttpStatusCode::SwitchingProtocols)) + { + OnUpgradedConnection(requestHandle); + } + else + { + int64_t contentLength + = GetContentLength(requestHandle, requestMethod, rawResponse->GetStatusCode()); - rawResponse->SetBodyStream( - std::make_unique<_detail::WinHttpStream>(std::move(handleManager), contentLength)); + rawResponse->SetBodyStream( + std::make_unique<_detail::WinHttpStream>(requestHandle, contentLength)); + } return rawResponse; } std::unique_ptr WinHttpTransport::Send(Request& request, Context const& context) { - auto handleManager = std::make_unique<_detail::HandleManager>(request, context); - - CreateConnectionHandle(handleManager); - CreateRequestHandle(handleManager); + _detail::unique_HINTERNET connectionHandle = CreateConnectionHandle(request.GetUrl(), context); + _detail::unique_HINTERNET requestHandle + = CreateRequestHandle(connectionHandle, request.GetUrl(), request.GetMethod()); - SendRequest(handleManager); + SendRequest(requestHandle, request, context); - ReceiveResponse(handleManager); + ReceiveResponse(requestHandle, context); - return SendRequestAndGetResponse(std::move(handleManager), request.GetMethod()); + return SendRequestAndGetResponse(requestHandle, request.GetMethod()); } // Read the response from the sent request. @@ -679,7 +711,7 @@ size_t _detail::WinHttpStream::OnRead(uint8_t* buffer, size_t count, Context con DWORD numberOfBytesRead = 0; if (!WinHttpReadData( - this->m_handleManager->m_requestHandle, + this->m_requestHandle.get(), (LPVOID)(buffer), static_cast(count), &numberOfBytesRead)) diff --git a/sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp b/sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp new file mode 100644 index 0000000000..7d869ca709 --- /dev/null +++ b/sdk/core/azure-core/src/http/winhttp/win_http_websockets.cpp @@ -0,0 +1,221 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "azure/core/http/http.hpp" +#include "azure/core/http/policies/policy.hpp" +#include "azure/core/http/transport.hpp" +#include "azure/core/http/websockets/win_http_websockets_transport.hpp" +#include "azure/core/internal/diagnostics/log.hpp" +#include "azure/core/platform.hpp" + +#if defined(AZ_PLATFORM_POSIX) +#include // for poll() +#include // for socket shutdown +#elif defined(AZ_PLATFORM_WINDOWS) +#if !defined(WIN32_LEAN_AND_MEAN) +#define WIN32_LEAN_AND_MEAN +#endif +#if !defined(NOMINMAX) +#define NOMINMAX +#endif +#include +#include // for WSAPoll(); +#endif +#include + +namespace Azure { namespace Core { namespace Http { namespace WebSockets { + + void WinHttpWebSocketTransport::OnUpgradedConnection( + Azure::Core::Http::_detail::unique_HINTERNET const& requestHandle) + { + // Convert the request handle into a WebSocket handle for us to use later. + m_socketHandle = Azure::Core::Http::_detail::unique_HINTERNET( + WinHttpWebSocketCompleteUpgrade(requestHandle.get(), 0), + Azure::Core::Http::_detail::HINTERNET_deleter{}); + if (!m_socketHandle) + { + GetErrorAndThrow("Error Upgrading HttpRequest handle to WebSocket handle."); + } + } + + std::unique_ptr WinHttpWebSocketTransport::Send( + Azure::Core::Http::Request& request, + Azure::Core::Context const& context) + { + return WinHttpTransport::Send(request, context); + } + + /** + * @brief Close the WebSocket cleanly. + */ + void WinHttpWebSocketTransport::Close() { m_socketHandle.reset(); } + + // Native WebSocket support methods. + /** + * @brief Gracefully closes the WebSocket, notifying the remote node of the close reason. + * + * @details Not implemented for CURL websockets because CURL does not support native websockets. + * + * @param status Status value to be sent to the remote node. Application defined. + * @param disconnectReason UTF-8 encoded reason for the disconnection. Optional. + * @param context Context for the operation. + * + */ + void WinHttpWebSocketTransport::NativeCloseSocket( + uint16_t status, + std::string const& disconnectReason, + Azure::Core::Context const& context) + { + context.ThrowIfCancelled(); + + auto err = WinHttpWebSocketClose( + m_socketHandle.get(), + status, + disconnectReason.empty() + ? nullptr + : reinterpret_cast(const_cast(disconnectReason.c_str())), + static_cast(disconnectReason.size())); + if (err != 0) + { + GetErrorAndThrow("WinHttpWebSocketClose() failed", err); + } + + context.ThrowIfCancelled(); + + // Make sure that the server responds gracefully to the close request. + auto closeInformation = NativeGetCloseSocketInformation(context); + + // The server should return the same status we sent. + if (closeInformation.CloseReason != status) + { + throw std::runtime_error( + "Close status mismatch, got " + std::to_string(closeInformation.CloseReason) + + " expected " + std::to_string(status)); + } + } + /** + * @brief Retrieve the information associated with a WebSocket close response. + * + * Should only be called when a Receive operation returns WebSocketFrameType::CloseFrameType + * + * @param context Context for the operation. + * + * @returns a tuple containing the status code and string. + */ + WinHttpWebSocketTransport::NativeWebSocketCloseInformation + WinHttpWebSocketTransport::NativeGetCloseSocketInformation(Azure::Core::Context const& context) + { + context.ThrowIfCancelled(); + uint16_t closeStatus = 0; + char closeReason[WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH]{}; + DWORD closeReasonLength; + + auto err = WinHttpWebSocketQueryCloseStatus( + m_socketHandle.get(), + &closeStatus, + closeReason, + WINHTTP_WEB_SOCKET_MAX_CLOSE_REASON_LENGTH, + &closeReasonLength); + if (err != 0) + { + GetErrorAndThrow("WinHttpGetCloseStatus() failed", err); + } + return NativeWebSocketCloseInformation{closeStatus, std::string(closeReason)}; + } + + /** + * @brief Send a frame of data to the remote node. + * + * @details Not implemented for CURL websockets because CURL does not support native + * websockets. + * + * @brief frameType Frame type sent to the server, Text or Binary. + * @brief frameData Frame data to be sent to the server. + */ + void WinHttpWebSocketTransport::NativeSendFrame( + NativeWebSocketFrameType frameType, + std::vector const& frameData, + Azure::Core::Context const& context) + { + context.ThrowIfCancelled(); + WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType; + switch (frameType) + { + case NativeWebSocketFrameType::Text: + bufferType = WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE; + break; + case NativeWebSocketFrameType::Binary: + bufferType = WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE; + break; + case NativeWebSocketFrameType::BinaryFragment: + bufferType = WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE; + break; + case NativeWebSocketFrameType::TextFragment: + bufferType = WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE; + break; + default: + throw std::runtime_error( + "Unknown frame type: " + std::to_string(static_cast(frameType))); + break; + } + // Lock the socket to prevent concurrent writes. WinHTTP gets annoyed if + // there are multiple WinHttpWebSocketSend requests outstanding. + std::lock_guard lock(m_sendMutex); + auto err = WinHttpWebSocketSend( + m_socketHandle.get(), + bufferType, + reinterpret_cast(const_cast(frameData.data())), + static_cast(frameData.size())); + if (err != 0) + { + GetErrorAndThrow("WinHttpWebSocketSend() failed", err); + } + } + + WinHttpWebSocketTransport::NativeWebSocketReceiveInformation + WinHttpWebSocketTransport::NativeReceiveFrame(Azure::Core::Context const& context) + { + WINHTTP_WEB_SOCKET_BUFFER_TYPE bufferType; + NativeWebSocketFrameType frameTypeReceived; + DWORD bufferBytesRead; + std::vector buffer(128); + context.ThrowIfCancelled(); + std::lock_guard lock(m_receiveMutex); + + auto err = WinHttpWebSocketReceive( + m_socketHandle.get(), + reinterpret_cast(buffer.data()), + static_cast(buffer.size()), + &bufferBytesRead, + &bufferType); + if (err != 0 && err != ERROR_INSUFFICIENT_BUFFER) + { + GetErrorAndThrow("WinHttpWebSocketReceive() failed", err); + } + buffer.resize(bufferBytesRead); + + switch (bufferType) + { + case WINHTTP_WEB_SOCKET_UTF8_MESSAGE_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::Text; + break; + case WINHTTP_WEB_SOCKET_BINARY_MESSAGE_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::Binary; + break; + case WINHTTP_WEB_SOCKET_BINARY_FRAGMENT_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::BinaryFragment; + break; + case WINHTTP_WEB_SOCKET_UTF8_FRAGMENT_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::TextFragment; + break; + case WINHTTP_WEB_SOCKET_CLOSE_BUFFER_TYPE: + frameTypeReceived = NativeWebSocketFrameType::Closed; + break; + default: + throw std::runtime_error("Unknown frame type: " + std::to_string(bufferType)); + break; + } + return NativeWebSocketReceiveInformation{frameTypeReceived, buffer}; + } + +}}}} // namespace Azure::Core::Http::WebSockets diff --git a/sdk/core/azure-core/test/ut/CMakeLists.txt b/sdk/core/azure-core/test/ut/CMakeLists.txt index bd75517ce3..ac0023446b 100644 --- a/sdk/core/azure-core/test/ut/CMakeLists.txt +++ b/sdk/core/azure-core/test/ut/CMakeLists.txt @@ -80,7 +80,8 @@ add_executable ( transport_adapter_implementation_test.cpp url_test.cpp uuid_test.cpp -) + websocket_test.cpp + ) if (MSVC) # Disable warnings: @@ -98,6 +99,17 @@ if (MSVC) target_compile_options(azure-core-test PUBLIC /wd26495 /wd26812 /wd6326 /wd28204 /wd28020 /wd6330 /wd4389) endif() +# Additional test files to be copied to the output directory. +set(TEST_ADDITIONAL_FILES ${CMAKE_CURRENT_LIST_DIR}/websocket_server.py + ${CMAKE_CURRENT_LIST_DIR}/requirements.txt + ${CMAKE_CURRENT_LIST_DIR}/Start-WebSocketServer.ps1 +) + +add_custom_command(TARGET azure-core-test POST_BUILD + COMMAND ${CMAKE_COMMAND} -E copy ${TEST_ADDITIONAL_FILES} ${CMAKE_CURRENT_BINARY_DIR} + DEPENDS ${TEST_ADDITIONAL_FILES} + COMMENT 'Copying non-source output files') + # Adding private headers from CORE to the tests so we can test the private APIs with no relative paths include. target_include_directories (azure-core-test PRIVATE $) diff --git a/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 b/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 new file mode 100644 index 0000000000..7e08280988 --- /dev/null +++ b/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 @@ -0,0 +1,13 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-License-Identifier: MIT +param( + [string] $LogFileLocation = "$($env:BUILD_SOURCESDIRECTORY)/WebSocketServer.log" +) + +if ($IsWindows) { + Start-Process 'python.exe' ` + -ArgumentList 'websocket_server.py' ` + -NoNewWindow -PassThru -RedirectStandardOutput $LogFileLocation +} else { + Start-Process nohup 'python3 websocket_server.py' -RedirectStandardOutput $LogFileLocation +} diff --git a/sdk/core/azure-core/test/ut/requirements.txt b/sdk/core/azure-core/test/ut/requirements.txt new file mode 100644 index 0000000000..14774b465e --- /dev/null +++ b/sdk/core/azure-core/test/ut/requirements.txt @@ -0,0 +1 @@ +websockets diff --git a/sdk/core/azure-core/test/ut/sha_test.cpp b/sdk/core/azure-core/test/ut/sha_test.cpp index 32bb33a26e..826f805503 100644 --- a/sdk/core/azure-core/test/ut/sha_test.cpp +++ b/sdk/core/azure-core/test/ut/sha_test.cpp @@ -7,38 +7,114 @@ using namespace Azure::Core::Cryptography::_internal; +// cspell: words ABCDE FGHIJ +TEST(SHA, SHA1Test) +{ + { + Sha1Hash sha; + Sha1Hash sha2; + uint8_t data[] = "A"; + auto shaResult = sha.Final(data, sizeof(data)); + auto shaResult2 = sha2.Final(data, sizeof(data)); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } + { + Sha1Hash sha; + Sha1Hash sha2; + std::string data1 = "ABCDE"; + std::string data2 = "FGHIJ"; + sha.Append(reinterpret_cast(data1.data()), data1.size()); + auto shaResult = sha.Final(reinterpret_cast(data2.data()), data2.size()); + auto shaResult2 = sha2.Final( + reinterpret_cast((data1 + data2).data()), data1.size() + data2.size()); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } +} + TEST(SHA, SHA256Test) { - Sha256Hash sha; - Sha256Hash sha2; - uint8_t data[] = "A"; - auto shaResult = sha.Final(data, sizeof(data)); - auto shaResult2 = sha2.Final(data, sizeof(data)); - EXPECT_EQ(shaResult, shaResult2); - for (size_t i = 0; i != shaResult.size(); i++) - printf("%02x", shaResult[i]); + { + + Sha256Hash sha; + Sha256Hash sha2; + uint8_t data[] = "A"; + auto shaResult = sha.Final(data, sizeof(data)); + auto shaResult2 = sha2.Final(data, sizeof(data)); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } + { + Sha256Hash sha; + Sha256Hash sha2; + std::string data1 = "ABCDE"; + std::string data2 = "FGHIJ"; + sha.Append(reinterpret_cast(data1.data()), data1.size()); + auto shaResult = sha.Final(reinterpret_cast(data2.data()), data2.size()); + auto shaResult2 = sha2.Final( + reinterpret_cast((data1 + data2).data()), data1.size() + data2.size()); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } } TEST(SHA, SHA384Test) { - Sha384Hash sha; - Sha384Hash sha2; - uint8_t data[] = "A"; - auto shaResult = sha.Final(data, sizeof(data)); - auto shaResult2 = sha2.Final(data, sizeof(data)); - EXPECT_EQ(shaResult, shaResult2); - for (size_t i = 0; i != shaResult.size(); i++) - printf("%02x", shaResult[i]); + { + + Sha384Hash sha; + Sha384Hash sha2; + uint8_t data[] = "A"; + auto shaResult = sha.Final(data, sizeof(data)); + auto shaResult2 = sha2.Final(data, sizeof(data)); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } + { + Sha384Hash sha; + Sha384Hash sha2; + std::string data1 = "ABCDE"; + std::string data2 = "FGHIJ"; + sha.Append(reinterpret_cast(data1.data()), data1.size()); + auto shaResult = sha.Final(reinterpret_cast(data2.data()), data2.size()); + auto shaResult2 = sha2.Final( + reinterpret_cast((data1 + data2).data()), data1.size() + data2.size()); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } } TEST(SHA, SHA512Test) { - Sha512Hash sha; - Sha512Hash sha2; - uint8_t data[] = "A"; - auto shaResult = sha.Final(data, sizeof(data)); - auto shaResult2 = sha2.Final(data, sizeof(data)); - EXPECT_EQ(shaResult, shaResult2); - for (size_t i = 0; i != shaResult.size(); i++) - printf("%02x", shaResult[i]); + { + + Sha512Hash sha; + Sha512Hash sha2; + uint8_t data[] = "A"; + auto shaResult = sha.Final(data, sizeof(data)); + auto shaResult2 = sha2.Final(data, sizeof(data)); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } + { + Sha512Hash sha; + Sha512Hash sha2; + std::string data1 = "ABCDE"; + std::string data2 = "FGHIJ"; + sha.Append(reinterpret_cast(data1.data()), data1.size()); + auto shaResult = sha.Final(reinterpret_cast(data2.data()), data2.size()); + auto shaResult2 = sha2.Final( + reinterpret_cast((data1 + data2).data()), data1.size() + data2.size()); + EXPECT_EQ(shaResult, shaResult2); + for (size_t i = 0; i != shaResult.size(); i++) + printf("%02x", shaResult[i]); + } } diff --git a/sdk/core/azure-core/test/ut/websocket_server.py b/sdk/core/azure-core/test/ut/websocket_server.py new file mode 100644 index 0000000000..ca7bcc077b --- /dev/null +++ b/sdk/core/azure-core/test/ut/websocket_server.py @@ -0,0 +1,155 @@ +# Copyright (c) Microsoft Corporation. All rights reserved. +# SPDX-License-Identifier: MIT +from array import array +import asyncio +from operator import length_hint +import threading +from time import sleep +from urllib.parse import ParseResult, urlparse + +import websockets + +# create handler for each connection +customPaths = {} +stop = False + +async def handleControlPath(websocket): + while (1): + data : str = await websocket.recv() + parsedCommand = data.split(' ') + if (parsedCommand[0] == "close"): + print("Closing control channel") + await websocket.send("ok") + print("Terminating WebSocket server.") + stop.set_result(0) + break + elif parsedCommand[0] == "newPath": + print("Add path") + newPath = parsedCommand[1] + print(" Add path ", newPath) + customPaths[newPath] = {"path": newPath, "delay": int(parsedCommand[2]) } + await websocket.send("ok") + else: + print("Unknown command, echoing it.") + await websocket.send(data) + +async def handleCustomPath(websocket, path:dict): + print("Handle custom path", path) + data : str = await websocket.recv() + print("Received ", data) + if ("delay" in path.keys()): + sleep(path["delay"]) + print("Responding") + await websocket.send(data) + await websocket.close() + +def HexEncode(data: bytes)->str: + rv="" + for val in data: + rv+= '{:02X}'.format(val) + return rv + +def ParseQuery(url : ParseResult) -> dict: + rv={} + if len(url.query)!=0: + args = url.query.split('&') + for arg in args: + vals=arg.split('=') + rv[vals[0]]=vals[1] + return rv + +echo_count_lock = threading.Lock() +echo_count_recv = 0 +echo_count_send = 0 +client_count = 0 +async def handleEcho(websocket, url:ParseResult): + global client_count + global echo_count_recv + global echo_count_send + global echo_count_lock + queryValues = ParseQuery(url) + while websocket.open: + try: + data = await websocket.recv() + with echo_count_lock: + echo_count_recv+=1 + if 'delay' in queryValues: + print(f"sleeping for {queryValues['delay']} seconds") + await asyncio.sleep(float(queryValues['delay'])) + print("woken up.") + + if 'fragment' in queryValues and queryValues['fragment']=='true': + await websocket.send(data.split()) + else: + await websocket.send(data) + with echo_count_lock: + echo_count_send+=1 + except websockets.ConnectionClosedOK: + print("Connection closed ok.") + with echo_count_lock: + client_count -= 1 + print(f"Echo count: {echo_count_recv}, {echo_count_send} client_count {client_count}") + if client_count == 0: + echo_count_send = 0 + echo_count_recv = 0 + return + except websockets.ConnectionClosed as ex: + if (ex.rcvd): + print(f"Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}") + else: + print(f"Connection closed. No close information.") + with echo_count_lock: + client_count -= 1 + print(f"Echo count: recv: {echo_count_recv}, send: {echo_count_send} client_count {client_count}") + if client_count == 0: + echo_count_send = 0 + echo_count_recv = 0 + return + +async def handler(websocket, path : str): + global client_count + print("Socket handler: ", path) + parsedUrl = urlparse(path) + if (parsedUrl.path == '/openclosetest'): + print("Open/Close Test") + try: + data = await websocket.recv() + print(f"OpenCloseTest: Received {data}") + except websockets.ConnectionClosedOK: + print("OpenCloseTest: Connection closed ok.") + except websockets.ConnectionClosed as ex: + print(f"OpenCloseTest: Connection closed exception: {ex.rcvd.code} {ex.rcvd.reason}") + return + elif (parsedUrl.path == '/echotest'): + with echo_count_lock: + client_count+= 1 + await handleEcho(websocket, parsedUrl) + elif (parsedUrl.path == '/closeduringecho'): + data = await websocket.recv() + await websocket.close(1001, 'closed') + elif (parsedUrl.path =='/control'): + await handleControlPath(websocket) + elif (parsedUrl.path in customPaths.keys()): + print("Found path ", path, "in control paths.") + await handleCustomPath(websocket, customPaths[path]) + elif (parsedUrl.path == '/terminateserver'): + print("Terminating WebSocket server.") + stop.set_result(0) + else: + data = await websocket.recv() + print("Received: ", data) + + reply = f"Data received as: {data}!" + await websocket.send(reply) + +async def main(): + global stop + print("Starting server") + loop = asyncio.get_running_loop() + stop = loop.create_future() + async with websockets.serve(handler, "localhost", 8000, ping_interval=7): + await stop # run forever. + +if __name__=="__main__": + asyncio.run(main()) + print("Ending server") diff --git a/sdk/core/azure-core/test/ut/websocket_test.cpp b/sdk/core/azure-core/test/ut/websocket_test.cpp new file mode 100644 index 0000000000..3160934962 --- /dev/null +++ b/sdk/core/azure-core/test/ut/websocket_test.cpp @@ -0,0 +1,875 @@ +// Copyright (c) Microsoft Corporation. All rights reserved. +// SPDX-License-Identifier: MIT + +#include "../../src/http/websockets/websockets_impl.hpp" +#include "azure/core/http/websockets/websockets.hpp" +#include "azure/core/internal/json/json.hpp" +#include +#include +#include +#include +#include +#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) +#include "azure/core/http/websockets/curl_websockets_transport.hpp" +#endif +// cspell::words closeme flibbityflobbidy + +using namespace Azure::Core; +using namespace Azure::Core::Http::WebSockets; +using namespace Azure::Core::Http::WebSockets::_internal; +using namespace std::chrono_literals; + +constexpr uint16_t UndefinedButLegalCloseReason = 4500; + +class WebSocketTests : public testing::Test { +private: +protected: + // Create + static void SetUpTestSuite() {} + static void TearDownTestSuite() {} +}; + +TEST_F(WebSocketTests, CreateSimpleSocket) +{ + { + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000")); + defaultSocket.AddHeader("newHeader", "headerValue"); + EXPECT_THROW(defaultSocket.GetNegotiatedProtocol(), std::runtime_error); + } +} + +TEST_F(WebSocketTests, OpenSimpleSocket) +{ + { + WebSocketOptions options; + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest"), options); + defaultSocket.AddHeader("newHeader", "headerValue"); + + defaultSocket.Open(); + + EXPECT_THROW(defaultSocket.AddHeader("newHeader", "headerValue"), std::runtime_error); + + // Close the socket without notifying the peer. + defaultSocket.Close(); + } + + { + WebSocketOptions options; + WebSocket defaultSocket(Azure::Core::Url("http://www.microsoft.com/"), options); + defaultSocket.AddHeader("newHeader", "headerValue"); + + // When running this test locally, the call times out, so drop in a 5 second timeout on + // the request. + Azure::Core::Context requestContext = Azure::Core::Context::ApplicationContext.WithDeadline( + std::chrono::system_clock::now() + 5s); + EXPECT_THROW(defaultSocket.Open(requestContext), std::runtime_error); + } +} + +TEST_F(WebSocketTests, OpenAndCloseSocket) +{ + if (false) + { + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest")); + defaultSocket.AddHeader("newHeader", "headerValue"); + + defaultSocket.Open(); + + // Close the socket without notifying the peer. + defaultSocket.Close(UndefinedButLegalCloseReason); + } + + { + WebSocket defaultSocket(Azure::Core::Url("http://localhost:8000/openclosetest")); + + defaultSocket.Open(); + + // Close the socket without notifying the peer. + defaultSocket.Close(UndefinedButLegalCloseReason, "This is a good reason."); + + // + // Now re-open the socket - this should work to reset everything. + defaultSocket.Open(); + EXPECT_THROW(defaultSocket.Open(), std::runtime_error); + defaultSocket.Close(); + } +} + +TEST_F(WebSocketTests, SimpleEcho) +{ + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + + testSocket.SendFrame("Test message", true); + + auto response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::TextFrameReceived, response->FrameType); + EXPECT_THROW(response->AsBinaryFrame(), std::logic_error); + auto textResult = response->AsTextFrame(); + EXPECT_EQ("Test message", textResult->Text); + + // Close the socket gracefully. + testSocket.Close(); + } + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?delay=5")); + + testSocket.Open(); + + std::vector binaryData{1, 2, 3, 4, 5, 6}; + + testSocket.SendFrame(binaryData, true); + + auto response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + EXPECT_THROW(response->AsPeerCloseFrame(), std::logic_error); + EXPECT_THROW(response->AsTextFrame(), std::logic_error); + auto textResult = response->AsBinaryFrame(); + EXPECT_EQ(binaryData, textResult->Data); + + // Close the socket gracefully. + testSocket.Close(); + } + + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest?fragment=true&delay=5")); + + testSocket.Open(); + + std::vector binaryData{1, 2, 3, 4, 5, 6}; + + testSocket.SendFrame(binaryData, true); + + std::vector responseData; + std::shared_ptr response; + do + { + response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + responseData.insert(responseData.end(), binaryResult->Data.begin(), binaryResult->Data.end()); + } while (!response->IsFinalFrame); + + auto textResult = response->AsBinaryFrame(); + EXPECT_EQ(binaryData, responseData); + + // Close the socket gracefully. + testSocket.Close(); + } +} + +template void EchoRandomData(WebSocket& socket) +{ + std::vector sendData = Azure::Core::Http::WebSockets::_detail::GenerateRandomBytes(N); + + socket.SendFrame(sendData, true); + + std::vector receiveData; + + std::shared_ptr response; + do + { + response = socket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + receiveData.insert(receiveData.end(), binaryResult->Data.begin(), binaryResult->Data.end()); + } while (!response->IsFinalFrame); + + // Make sure we get back the data we sent in the echo request. + EXPECT_EQ(sendData.size(), receiveData.size()); + EXPECT_EQ(sendData, receiveData); +} + +TEST_F(WebSocketTests, VariableSizeEcho) +{ + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + { + EchoRandomData<100>(testSocket); + EchoRandomData<124>(testSocket); + EchoRandomData<125>(testSocket); + // The websocket protocol treats lengths of 125, 126 and > 127 specially. + EchoRandomData<126>(testSocket); + EchoRandomData<127>(testSocket); + EchoRandomData<128>(testSocket); + EchoRandomData<1020>(testSocket); // 1K-4 + EchoRandomData<1021>(testSocket); // 1K-3 + EchoRandomData<1022>(testSocket); // 1K-2 + EchoRandomData<1023>(testSocket); // 1K-1 + EchoRandomData<1024>(testSocket); // 1K + EchoRandomData<2048>(testSocket); // 2K + EchoRandomData<4096>(testSocket); // 4K + EchoRandomData<8192>(testSocket); // 8K + // The websocket protocol treats lengths of >65536 specially. + EchoRandomData<65535>(testSocket); // 64K-1 + EchoRandomData<65536>(testSocket); // 64K + EchoRandomData<65537>(testSocket); // 64K+1 + EchoRandomData<131072>(testSocket); // 128K + } + // Close the socket gracefully. + testSocket.Close(); + } +} + +// Generator for random bytes. Used in WebSocketImplementation and tests. +std::vector GenerateRandomBytes(size_t index, size_t vectorSize) +{ + std::random_device randomEngine; + + std::vector rv(vectorSize + 4); + rv[0] = index & 0xff; + rv[1] = (index >> 8) & 0xff; + rv[2] = (index >> 16) & 0xff; + rv[3] = (index >> 24) & 0xff; + std::generate(std::begin(rv) + 4, std::end(rv), [&randomEngine]() mutable { + return static_cast(randomEngine() % UINT8_MAX); + }); + return rv; +} + +TEST_F(WebSocketTests, CloseDuringEcho) +{ + { + WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho")); + + testSocket.Open(); + + testSocket.SendFrame("Test message", true); + + auto response = testSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::PeerClosedReceived, response->FrameType); + auto PeerClosedReceived = response->AsPeerCloseFrame(); + EXPECT_EQ(1001, PeerClosedReceived->RemoteStatusCode); + + // Close the socket gracefully. + testSocket.Close(); + } + + // Close the websocket while a thread is waiting for a response. + { + WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/echotest?delay=10")); + + testSocket.Open(); + + std::thread testThread([&]() { + try + { + std::vector sendData = GenerateRandomBytes(0, 100); + testSocket.SendFrame(sendData); + GTEST_LOG_(INFO) << "Receive frame."; + auto response = testSocket.ReceiveFrame(); + GTEST_LOG_(INFO) << "Received frame."; + if (response->FrameType == WebSocketFrameType::PeerClosedReceived) + { + GTEST_LOG_(INFO) << "Peer closed the socket; Terminating thread."; + return; + } + else if (response->FrameType != WebSocketFrameType::BinaryFrameReceived) + { + GTEST_LOG_(INFO) << "Unexpected frame type received."; + } + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + } + catch (Azure::Core::OperationCancelledException& ex) + { + GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() + << " Current Thread: " << std::this_thread::get_id() << std::endl; + } + catch (std::exception const& ex) + { + GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl; + } + }); + + std::this_thread::sleep_for(100ms); + + // Close the socket gracefully. + GTEST_LOG_(INFO) << "Closing Socket."; + EXPECT_NO_THROW(testSocket.Close(UndefinedButLegalCloseReason, "Close Reason.")); + GTEST_LOG_(INFO) << "Closed Socket."; + testThread.join(); + } +} + +TEST_F(WebSocketTests, ExpectThrow) +{ + { + WebSocket testSocket(Azure::Core::Url("ws://localhost:8000/closeduringecho")); + + EXPECT_THROW(testSocket.SendFrame("Foo", true), std::runtime_error); + std::vector data{1, 2, 3, 4}; + EXPECT_THROW(testSocket.SendFrame(data, true), std::runtime_error); + EXPECT_THROW(testSocket.ReceiveFrame(), std::runtime_error); + } +} + +std::string ToHexString(std::vector const& data) +{ + std::stringstream ss; + for (auto const& byte : data) + { + ss << std::hex << std::setfill('0') << std::setw(2) << static_cast(byte); + } + return ss.str(); +} + +TEST_F(WebSocketTests, PingReceiveTest) +{ + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + if (!testSocket.HasBuiltInWebSocketSupport()) + { + + GTEST_LOG_(INFO) << "Sleeping for 15 seconds to collect pings."; + Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{std::chrono::system_clock::now() + 15s}); + EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException); + auto statistics = testSocket.GetStatistics(); + GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent; + GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived; + GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived; + GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent; + GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived; + GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent; + GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent; + GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived; + GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped; + GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads; + GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes; + EXPECT_NE(0, statistics.PingFramesReceived); + EXPECT_NE(0, statistics.PongFramesSent); + } +} + +TEST_F(WebSocketTests, PingSendTest) +{ + // Configure the socket to ping every second. + WebSocketOptions socketOptions; + socketOptions.PingInterval = std::chrono::seconds(1); + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest"), socketOptions); + + testSocket.Open(); + if (!testSocket.HasBuiltInWebSocketSupport()) + { + + GTEST_LOG_(INFO) << "Sleeping for 10 seconds to collect pings."; + // Note that we cannot collect incoming pings or outgoing pongs unless we are receiving + // data from the server. + Azure::Core::Context receiveContext = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{std::chrono::system_clock::now() + 10s}); + EXPECT_THROW(testSocket.ReceiveFrame(receiveContext), Azure::Core::OperationCancelledException); + auto statistics = testSocket.GetStatistics(); + GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent; + GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived; + GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived; + GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent; + GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived; + GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent; + GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent; + GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived; + GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped; + GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads; + GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes; + EXPECT_NE(0, statistics.PingFramesSent); + EXPECT_NE(0, statistics.PongFramesReceived); + EXPECT_NE(0, statistics.PingFramesReceived); + EXPECT_NE(0, statistics.PongFramesSent); + } +} + +TEST_F(WebSocketTests, MultiThreadedTestOnSingleSocket) +{ + constexpr size_t threadCount = 50; + constexpr size_t testDataLength = 200000; + constexpr size_t testDataSize = 100; + constexpr auto testDuration = 10s; + + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + + // seed test data for the operations. + std::vector> testData(testDataLength); + std::vector> receivedData(testDataLength); + std::atomic_size_t iterationCount(0); + + // Spin up threadCount threads and hammer the echo server for 10 seconds. + std::vector threads; + std::atomic_int32_t cancellationExceptions{0}; + std::atomic_int32_t exceptions{0}; + for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1) + { + threads.push_back(std::thread([&]() { + std::chrono::time_point startTime + = std::chrono::system_clock::now(); + // Set the context to expire *after* the test is supposed to finish. + Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{startTime} + testDuration + 10s); + size_t iteration = 0; + try + { + do + { + iteration = iterationCount++; + std::vector sendData = GenerateRandomBytes(iteration, testDataSize); + { + if (iteration < testData.size()) + { + if (testData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl; + } + EXPECT_EQ(0, testData[iteration].size()); + testData[iteration] = sendData; + } + } + + testSocket.SendFrame(sendData, true /*, context*/); + auto response = testSocket.ReceiveFrame(context); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + + // Make sure we get back the data we sent in the echo request. + if (binaryResult->Data.size() == 0) + { + GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl; + } + EXPECT_EQ(sendData.size(), binaryResult->Data.size()); + { + // There is no ordering expectation on the results, so we just remember the data + // as it comes in. We'll make sure we received everything later on. + if (iteration < receivedData.size()) + { + if (receivedData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration + << std::endl; + } + + EXPECT_EQ(0, receivedData[iteration].size()); + receivedData[iteration] = binaryResult->Data; + } + } + } while (std::chrono::system_clock::now() - startTime < testDuration); + } + catch (Azure::Core::OperationCancelledException& ex) + { + GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration + << " Current Thread: " << std::this_thread::get_id() << std::endl; + cancellationExceptions++; + } + catch (std::exception const& ex) + { + GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl; + exceptions++; + } + })); + } + + // Wait for all the threads to exit. + for (auto& thread : threads) + { + thread.join(); + } + + // We no longer need to worry about synchronization since all the worker threads are done. + GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl; + GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex + << testData.size() << ")" << std::endl; + EXPECT_GE(testDataLength, iterationCount.load()); + + auto statistics = testSocket.GetStatistics(); + GTEST_LOG_(INFO) << "Total bytes sent: " << std::dec << statistics.BytesSent; + GTEST_LOG_(INFO) << "Total bytes received: " << std::dec << statistics.BytesReceived; + GTEST_LOG_(INFO) << "Ping Frames received: " << std::dec << statistics.PingFramesReceived; + GTEST_LOG_(INFO) << "Ping Frames sent: " << std::dec << statistics.PingFramesSent; + GTEST_LOG_(INFO) << "Pong Frames received: " << std::dec << statistics.PongFramesReceived; + GTEST_LOG_(INFO) << "Pong Frames sent: " << std::dec << statistics.PongFramesSent; + GTEST_LOG_(INFO) << "Binary frames sent: " << std::dec << statistics.BinaryFramesSent; + GTEST_LOG_(INFO) << "Binary frames received: " << std::dec << statistics.BinaryFramesReceived; + GTEST_LOG_(INFO) << "Total frames lost: " << std::dec << statistics.FramesDropped; + GTEST_LOG_(INFO) << "Transport Reads " << std::dec << statistics.TransportReads; + GTEST_LOG_(INFO) << "Transport Bytes Read " << std::dec << statistics.TransportReadBytes; + + // Close the socket gracefully. + testSocket.Close(); + + EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesSent); + EXPECT_EQ(iterationCount.load(), statistics.BinaryFramesReceived); + + // Resize the test data to the number of actual iterations. + testData.resize(iterationCount.load()); + receivedData.resize(iterationCount.load()); + + // If we've processed every iteration, let's make sure that we received everything we sent. + // If we dropped some results, then we can't check to ensure that we have received everything + // because we can't account for everything sent. + std::multiset testDataStrings; + std::multiset receivedDataStrings; + for (auto const& data : testData) + { + testDataStrings.emplace(ToHexString(data)); + } + for (auto const& data : receivedData) + { + receivedDataStrings.emplace(ToHexString(data)); + } + + EXPECT_EQ(testDataStrings, receivedDataStrings); + for (auto const& data : testDataStrings) + { + if (receivedDataStrings.count(data) != testDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data)); + } + for (auto const& data : receivedDataStrings) + { + if (testDataStrings.count(data) != receivedDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + + EXPECT_NE(testDataStrings.end(), testDataStrings.find(data)); + } + + // We shouldn't have seen any exceptions during the run. + EXPECT_EQ(0, exceptions.load()); + EXPECT_EQ(0, cancellationExceptions.load()); +} + +TEST_F(WebSocketTests, MultiThreadedTestOnMultipleSockets) +{ + constexpr size_t threadCount = 50; + constexpr size_t testDataLength = 200000; + constexpr size_t testDataSize = 100; + constexpr auto testDuration = 10s; + + // seed test data for the operations. + std::vector> testData(testDataLength); + std::vector> receivedData(testDataLength); + std::atomic_size_t iterationCount(0); + + // Spin up threadCount threads and hammer the echo server for 10 seconds. + std::vector threads; + std::atomic_int32_t cancellationExceptions{0}; + std::atomic_int32_t exceptions{0}; + for (size_t threadIndex = 0; threadIndex < threadCount; threadIndex += 1) + { + threads.push_back(std::thread([&]() { + std::chrono::time_point startTime + = std::chrono::system_clock::now(); + // Set the context to expire *after* the test is supposed to finish. + Azure::Core::Context context = Azure::Core::Context::ApplicationContext.WithDeadline( + Azure::DateTime{startTime} + testDuration + 10s); + size_t iteration = 0; + try + { + WebSocket testSocket(Azure::Core::Url("http://localhost:8000/echotest")); + + testSocket.Open(); + + do + { + iteration = iterationCount++; + std::vector sendData = GenerateRandomBytes(iteration, testDataSize); + { + if (iteration < testData.size()) + { + if (testData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting send frame at offset " << iteration << std::endl; + } + EXPECT_EQ(0, testData[iteration].size()); + testData[iteration] = sendData; + } + } + + testSocket.SendFrame(sendData, true /*, context*/); + auto response = testSocket.ReceiveFrame(context); + EXPECT_EQ(WebSocketFrameType::BinaryFrameReceived, response->FrameType); + auto binaryResult = response->AsBinaryFrame(); + + // Make sure we get back the data we sent in the echo request. + if (binaryResult->Data.size() == 0) + { + GTEST_LOG_(ERROR) << "Received empty frame at offset " << iteration << std::endl; + } + EXPECT_EQ(sendData.size(), binaryResult->Data.size()); + { + // There is no ordering expectation on the results, so we just remember the data + // as it comes in. We'll make sure we received everything later on. + if (iteration < receivedData.size()) + { + if (receivedData[iteration].size() != 0) + { + GTEST_LOG_(ERROR) << "Overwriting receive frame at offset " << iteration + << std::endl; + } + + EXPECT_EQ(0, receivedData[iteration].size()); + receivedData[iteration] = binaryResult->Data; + } + } + } while (std::chrono::system_clock::now() - startTime < testDuration); + // Close the socket gracefully. + testSocket.Close(); + } + catch (Azure::Core::OperationCancelledException& ex) + { + GTEST_LOG_(ERROR) << "Cancelled Exception: " << ex.what() << " at index " << iteration + << " Current Thread: " << std::this_thread::get_id() << std::endl; + cancellationExceptions++; + } + catch (std::exception const& ex) + { + GTEST_LOG_(ERROR) << "Exception: " << ex.what() << std::endl; + exceptions++; + } + })); + } + + // Wait for all the threads to exit. + for (auto& thread : threads) + { + thread.join(); + } + + // We no longer need to worry about synchronization since all the worker threads are done. + GTEST_LOG_(INFO) << "Total server requests: " << iterationCount.load() << std::endl; + GTEST_LOG_(INFO) << "Estimated " << std::dec << testData.size() << " iterations (0x" << std::hex + << testData.size() << ")" << std::endl; + EXPECT_GE(testDataLength, iterationCount.load()); + + // Resize the test data to the number of actual iterations. + testData.resize(iterationCount.load()); + receivedData.resize(iterationCount.load()); + + // If we've processed every iteration, let's make sure that we received everything we sent. + // If we dropped some results, then we can't check to ensure that we have received everything + // because we can't account for everything sent. + std::multiset testDataStrings; + std::multiset receivedDataStrings; + for (auto const& data : testData) + { + testDataStrings.emplace(ToHexString(data)); + } + for (auto const& data : receivedData) + { + receivedDataStrings.emplace(ToHexString(data)); + } + + EXPECT_EQ(testDataStrings, receivedDataStrings); + for (auto const& data : testDataStrings) + { + if (receivedDataStrings.count(data) != testDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Missing data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + EXPECT_NE(receivedDataStrings.end(), receivedDataStrings.find(data)); + } + for (auto const& data : receivedDataStrings) + { + if (testDataStrings.count(data) != receivedDataStrings.count(data)) + { + GTEST_LOG_(INFO) << "Extra data. TestDataCount: " << testDataStrings.count(data) + << " ReceivedDataCount: " << receivedDataStrings.count(data) + << " Missing Data: " << data << std::endl; + } + + EXPECT_NE(testDataStrings.end(), testDataStrings.find(data)); + } + + // We shouldn't have seen any exceptions during the run. + EXPECT_EQ(0, exceptions.load()); + EXPECT_EQ(0, cancellationExceptions.load()); +} + +// Does not work because curl rejects the wss: scheme. +class LibWebSocketIncrementProtocol { + WebSocketOptions m_options{{"dumb-increment-protocol"}}; + WebSocket m_socket; + +public: + LibWebSocketIncrementProtocol() : m_socket{Azure::Core::Url("wss://libwebsockets.org"), m_options} + { + } + + void Open() { m_socket.Open(); } + int GetNextNumber() + { + // Time out in 5 seconds if no activity. + Azure::Core::Context contextWithTimeout + = Azure::Core::Context().WithDeadline(std::chrono::system_clock::now() + 10s); + auto work = m_socket.ReceiveFrame(contextWithTimeout); + if (work->FrameType == WebSocketFrameType::TextFrameReceived) + { + auto frame = work->AsTextFrame(); + return std::atoi(frame->Text.c_str()); + } + if (work->FrameType == WebSocketFrameType::BinaryFrameReceived) + { + auto frame = work->AsBinaryFrame(); + throw std::runtime_error("Not implemented"); + } + else if (work->FrameType == WebSocketFrameType::PeerClosedReceived) + { + GTEST_LOG_(INFO) << "Remote server closed connection." << std::endl; + throw std::runtime_error("Remote server closed connection."); + } + else + { + throw std::runtime_error("Unknown result type"); + } + } + + void Reset() { m_socket.SendFrame("reset\n", true); } + void RequestClose() { m_socket.SendFrame("closeme\n", true); } + void Close() { m_socket.Close(); } + void Close(uint16_t closeCode, std::string const& reasonText = {}) + { + m_socket.Close(closeCode, reasonText); + } + void ConsumeUntilClosed() + { + while (m_socket.IsOpen()) + { + auto work = m_socket.ReceiveFrame(); + if (work->FrameType == WebSocketFrameType::PeerClosedReceived) + { + auto peerClose = work->AsPeerCloseFrame(); + GTEST_LOG_(INFO) << "Peer closed. Remote Code: " << std::dec << peerClose->RemoteStatusCode + << " (0x" << std::hex << peerClose->RemoteStatusCode << ")" << std::endl; + if (!peerClose->RemoteCloseReason.empty()) + { + GTEST_LOG_(INFO) << " Peer Closed Data: " << peerClose->RemoteCloseReason; + } + GTEST_LOG_(INFO) << std::endl; + return; + } + else if (work->FrameType == WebSocketFrameType::TextFrameReceived) + { + auto frame = work->AsTextFrame(); + GTEST_LOG_(INFO) << "Ignoring " << frame->Text << std::endl; + } + } + } +}; + +class LibWebSocketStatus { + +public: + std::string GetLWSStatus() + { + WebSocketOptions options; + + options.ServiceName = "websockettest"; + // Send 3 protocols to LWS. + options.Protocols.push_back("brownCow"); + options.Protocols.push_back("lws-status"); + options.Protocols.push_back("flibbityflobbidy"); + WebSocket serverSocket(Azure::Core::Url("wss://libwebsockets.org"), options); + serverSocket.Open(); + + // The server should have chosen the lws-status protocol since it doesn't understand the other + // protocols. + EXPECT_EQ("lws-status", serverSocket.GetNegotiatedProtocol()); + std::string returnValue; + std::shared_ptr lwsStatus; + do + { + + lwsStatus = serverSocket.ReceiveFrame(); + EXPECT_EQ(WebSocketFrameType::TextFrameReceived, lwsStatus->FrameType); + if (lwsStatus->FrameType == WebSocketFrameType::TextFrameReceived) + { + auto textFrame = lwsStatus->AsTextFrame(); + returnValue.insert(returnValue.end(), textFrame->Text.begin(), textFrame->Text.end()); + } + } while (!lwsStatus->IsFinalFrame); + serverSocket.Close(); + return returnValue; + } +}; + +TEST_F(WebSocketTests, LibWebSocketOrgLwsStatus) +{ + { + LibWebSocketStatus lwsStatus; + auto serverStatus = lwsStatus.GetLWSStatus(); + GTEST_LOG_(INFO) << "Server status: " << serverStatus << std::endl; + + Azure::Core::Json::_internal::json status; + EXPECT_NO_THROW(status = Azure::Core::Json::_internal::json::parse(serverStatus)); + EXPECT_TRUE(status["conns"].is_array()); + auto& connections = status["conns"].get_ref&>(); + bool foundOurConnection = false; + + // Scan through the list of connections to find a connection from the websockettest. + for (auto& connection : connections) + { + EXPECT_TRUE(connection["ua"].is_string()); + auto userAgent = connection["ua"].get(); + if (userAgent.find("websockettest") != std::string::npos) + { + foundOurConnection = true; + break; + } + } + EXPECT_TRUE(foundOurConnection); + } +} +TEST_F(WebSocketTests, LibWebSocketOrgIncrement) +{ + { + LibWebSocketIncrementProtocol incrementProtocol; + incrementProtocol.Open(); + + // Note that we cannot practically validate the numbers received from the service because + // they may be in flight at the time the "Reset" call is made. + for (auto i = 0; i < 100; i += 1) + { + if (i % 5 == 0) + { + GTEST_LOG_(INFO) << "Reset" << std::endl; + incrementProtocol.Reset(); + } + int number = incrementProtocol.GetNextNumber(); + GTEST_LOG_(INFO) << "Got next number " << number << std::endl; + } + incrementProtocol.RequestClose(); + incrementProtocol.ConsumeUntilClosed(); + } +} +#if defined(BUILD_CURL_HTTP_TRANSPORT_ADAPTER) +TEST_F(WebSocketTests, CurlTransportCoverage) +{ + { + + Azure::Core::Http::WebSockets::CurlWebSocketTransportOptions transportOptions; + transportOptions.HttpKeepAlive = false; + auto transport + = std::make_shared(transportOptions); + + EXPECT_THROW(transport->NativeCloseSocket(1001, {}, {}), std::runtime_error); + EXPECT_THROW(transport->NativeGetCloseSocketInformation({}), std::runtime_error); + EXPECT_THROW( + transport->NativeSendFrame(WebSocketTransport::NativeWebSocketFrameType::Binary, {}, {}), + std::runtime_error); + EXPECT_THROW(transport->NativeReceiveFrame({}), std::runtime_error); + } +} +#endif diff --git a/sdk/core/ci.yml b/sdk/core/ci.yml index 8547868d47..34be7a63ef 100644 --- a/sdk/core/ci.yml +++ b/sdk/core/ci.yml @@ -42,6 +42,56 @@ stages: LiveTestTimeoutInMinutes: 90 # default is 60 min. We need a little longer on worst case for Win+jsonTests LineCoverageTarget: 93 BranchCoverageTarget: 55 + PreTestSteps: + - task: UsePythonVersion@0 + displayName: 'Use Python 3' + inputs: + versionSpec: '3' + condition: and(succeeded(), contains(variables.CmakeArgs, 'BUILD_TESTING=ON')) + + - pwsh: | + python --version + pip install -r requirements.txt + workingDirectory: build/sdk/core/azure-core/test/ut + displayName: Install Python requirements. + condition: and(succeeded(), contains(variables.CmakeArgs, 'BUILD_TESTING=ON')) + - task: PowerShell@2 + displayName: 'Launch python websocket server' + inputs: + pwsh: true + filePath: build/sdk/core/azure-core/test/ut/Start-WebSocketServer.ps1 + arguments: $(Build.SourcesDirectory)/WebSocketServer.log + workingDirectory: build/sdk/core/azure-core/test/ut + condition: and(succeeded(), eq(variables['Agent.OS'], 'Windows_NT'), contains(variables.CmakeArgs, 'BUILD_TESTING=ON')) + # It would be nice to collapse this branch with the previous one, but nohup doesn't seem to + # behave when called from powershell. + - bash: | + nohup python sdk/core/azure-core/test/ut/websocket_server.py > $(Build.SourcesDirectory)/WebSocketServer.log & + workingDirectory: build + condition: and(succeeded(), ne(variables['Agent.OS'], 'Windows_NT'), contains(variables.CmakeArgs, 'BUILD_TESTING=ON')) + displayName: Launch python websocket server (Linux). + PostTestSteps: + # Shut down the test server. This uses curl to send a request to the "terminateserver" websocket endpoint. + # When the test server receives a request on terminateserver, it shuts down gracefully. + - pwsh: | + curl ` + --include ` + --no-buffer ` + --header "Connection: Upgrade" ` + --header "Upgrade: websocket" ` + --header "Host: localhost:8000" ` + --header "Origin: http://localhost:8000" ` + --header "Sec-WebSocket-Key: eaQZ9ed+LnT0zs5EvI04aQ==" ` + --header "Sec-WebSocket-Version: 13" ` + http://localhost:8000/terminateserver + displayName: Shutdown WebSocket server. + condition: contains(variables.CmakeArgs, 'BUILD_TESTING=ON') + - template: /eng/common/pipelines/templates/steps/publish-artifact.yml + parameters: + ArtifactPath: '$(Build.SourcesDirectory)/WebSocketServer.log' + ArtifactName: 'WebSocketLogs-$(Agent.JobName)_attempt_$(System.JobAttempt)' + CustomCondition: contains(variables.CmakeArgs, 'BUILD_TESTING=ON') + Artifacts: - Name: azure-core Path: azure-core