diff --git a/include/envoy/ssl/handshaker.h b/include/envoy/ssl/handshaker.h index 42d20601071bc..e6d7715948f01 100644 --- a/include/envoy/ssl/handshaker.h +++ b/include/envoy/ssl/handshaker.h @@ -5,6 +5,8 @@ #include "envoy/network/connection.h" #include "envoy/network/post_io_action.h" #include "envoy/protobuf/message_validator.h" +#include "envoy/ssl/connection.h" +#include "envoy/ssl/ssl_socket_state.h" #include "openssl/ssl.h" @@ -46,9 +48,38 @@ class Handshaker { virtual Network::PostIoAction doHandshake() PURE; }; -using HandshakerSharedPtr = std::shared_ptr; -using HandshakerFactoryCb = - std::function, int, HandshakeCallbacks*)>; +/** + * Base interface for a combined `handshaker-and-connectioninfo` class + * which can both perform handshakes and provide connection-specific + * information. + */ +class HandshakerAndConnectionInfo : public Handshaker, public ConnectionInfo { +public: + /** + * Return the current SocketState. + */ + virtual SocketState state() PURE; + + /** + * Update the SocketState. + */ + virtual void setState(SocketState state) PURE; + + /** + * Returns a pointer to the SSL object. + */ + virtual SSL* ssl() const PURE; + + /** + * Returns a pointer to the HandshakeCallbacks object. + */ + virtual HandshakeCallbacks* handshakeCallbacks() PURE; +}; + +using HandshakerAndConnectionInfoSharedPtr = std::shared_ptr; + +using HandshakerFactoryCb = std::function, int, HandshakeCallbacks*)>; class HandshakerFactoryContext { public: diff --git a/source/extensions/transport_sockets/tls/ssl_handshaker.h b/source/extensions/transport_sockets/tls/ssl_handshaker.h index 8eaec861a8f13..4628d86a48efb 100644 --- a/source/extensions/transport_sockets/tls/ssl_handshaker.h +++ b/source/extensions/transport_sockets/tls/ssl_handshaker.h @@ -37,7 +37,7 @@ class SslExtendedSocketInfoImpl : public Envoy::Ssl::SslExtendedSocketInfo { Envoy::Ssl::ClientValidationStatus::NotValidated}; }; -class SslHandshakerImpl : public Ssl::ConnectionInfo, public Ssl::Handshaker { +class SslHandshakerImpl : public Ssl::HandshakerAndConnectionInfo { public: SslHandshakerImpl(bssl::UniquePtr ssl, int ssl_extended_socket_info_index, Ssl::HandshakeCallbacks* handshake_callbacks); @@ -67,10 +67,11 @@ class SslHandshakerImpl : public Ssl::ConnectionInfo, public Ssl::Handshaker { // Ssl::Handshaker Network::PostIoAction doHandshake() override; - Ssl::SocketState state() { return state_; } - void setState(Ssl::SocketState state) { state_ = state; } - SSL* ssl() const { return ssl_.get(); } - Ssl::HandshakeCallbacks* handshakeCallbacks() { return handshake_callbacks_; } + // Ssl::HandshakerAndConnectionInfo + Ssl::SocketState state() override { return state_; } + void setState(Ssl::SocketState state) override { state_ = state; } + SSL* ssl() const override { return ssl_.get(); } + Ssl::HandshakeCallbacks* handshakeCallbacks() override { return handshake_callbacks_; } bssl::UniquePtr ssl_; @@ -95,8 +96,6 @@ class SslHandshakerImpl : public Ssl::ConnectionInfo, public Ssl::Handshaker { mutable SslExtendedSocketInfoImpl extended_socket_info_; }; -using SslHandshakerImplSharedPtr = std::shared_ptr; - class HandshakerFactoryContextImpl : public Ssl::HandshakerFactoryContext { public: HandshakerFactoryContextImpl(Api::Api& api, absl::string_view alpn_protocols) diff --git a/source/extensions/transport_sockets/tls/ssl_socket.cc b/source/extensions/transport_sockets/tls/ssl_socket.cc index 4854684430963..48d088a0338fc 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.cc +++ b/source/extensions/transport_sockets/tls/ssl_socket.cc @@ -50,9 +50,8 @@ SslSocket::SslSocket(Envoy::Ssl::ContextSharedPtr ctx, InitialState state, Ssl::HandshakerFactoryCb handshaker_factory_cb) : transport_socket_options_(transport_socket_options), ctx_(std::dynamic_pointer_cast(ctx)), - info_(std::dynamic_pointer_cast( - handshaker_factory_cb(ctx_->newSsl(transport_socket_options_.get()), - ctx_->sslExtendedSocketInfoIndex(), this))) { + info_(handshaker_factory_cb(ctx_->newSsl(transport_socket_options_.get()), + ctx_->sslExtendedSocketInfoIndex(), this)) { if (state == InitialState::Client) { SSL_set_connect_state(rawSsl()); } else { diff --git a/source/extensions/transport_sockets/tls/ssl_socket.h b/source/extensions/transport_sockets/tls/ssl_socket.h index ba73cc5d6ac63..eb9f77583d9b6 100644 --- a/source/extensions/transport_sockets/tls/ssl_socket.h +++ b/source/extensions/transport_sockets/tls/ssl_socket.h @@ -72,7 +72,7 @@ class SslSocket : public Network::TransportSocket, SSL* rawSslForTest() const { return rawSsl(); } protected: - SSL* rawSsl() const { return info_->ssl_.get(); } + SSL* rawSsl() const { return info_->ssl(); } private: struct ReadResult { @@ -94,7 +94,7 @@ class SslSocket : public Network::TransportSocket, uint64_t bytes_to_retry_{}; std::string failure_reason_; - SslHandshakerImplSharedPtr info_; + Ssl::HandshakerAndConnectionInfoSharedPtr info_; }; class ClientSslSocketFactory : public Network::TransportSocketFactory,