diff --git a/include/envoy/network/transport_socket.h b/include/envoy/network/transport_socket.h index a5390c29853ad..7bb9e382b3d9e 100644 --- a/include/envoy/network/transport_socket.h +++ b/include/envoy/network/transport_socket.h @@ -137,6 +137,11 @@ class TransportSocket { * @return the const SSL connection data if this is an SSL connection, or nullptr if it is not. */ virtual const Ssl::Connection* ssl() const PURE; + + /** + * @return the TransportSocketCallbacks object passed in through setTransportSocketCallbacks(). + */ + virtual TransportSocketCallbacks* callbacks() PURE; }; typedef std::unique_ptr TransportSocketPtr; diff --git a/source/common/network/raw_buffer_socket.h b/source/common/network/raw_buffer_socket.h index 3ab5ac0a27254..c7f0f53e396a3 100644 --- a/source/common/network/raw_buffer_socket.h +++ b/source/common/network/raw_buffer_socket.h @@ -21,6 +21,7 @@ class RawBufferSocket : public TransportSocket, protected Logger::Loggablecallbacks()->fd(); } + Network::Connection& connection() override { return parent_->callbacks()->connection(); } + bool shouldDrainReadBuffer() override { return false; } + /* + * No-op for these two methods to hold back the callbacks. + */ + void setReadBufferReady() override {} + void raiseEvent(Network::ConnectionEvent) override {} + +private: + Network::TransportSocket* parent_; +}; + +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/transport_sockets/capture/capture.h b/source/extensions/transport_sockets/capture/capture.h index b46f94bf29301..73298b76986d7 100644 --- a/source/extensions/transport_sockets/capture/capture.h +++ b/source/extensions/transport_sockets/capture/capture.h @@ -27,6 +27,7 @@ class CaptureSocket : public Network::TransportSocket { void onConnected() override; Ssl::Connection* ssl() override; const Ssl::Connection* ssl() const override; + Network::TransportSocketCallbacks* callbacks() { return callbacks_; } private: const std::string& path_prefix_; diff --git a/source/extensions/transport_sockets/noop_transport_socket_callbacks.h b/source/extensions/transport_sockets/noop_transport_socket_callbacks.h new file mode 100644 index 0000000000000..ef11cf3335ba0 --- /dev/null +++ b/source/extensions/transport_sockets/noop_transport_socket_callbacks.h @@ -0,0 +1,32 @@ +#include "third_party/envoy/src/include/envoy/network/transport_socket.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { + +/** + * A TransportSocketCallbacks for wrapped TransportSocket object. Some + * TransportSocket implementation wraps another socket which does actual I/O. + * This class is used by the wrapped socket as its callbacks instead of the real + * connection to hold back callbacks from the underlying socket to connection. + */ +class NoOpTransportSocketCallbacks : public Network::TransportSocketCallbacks { +public: + explicit NoOpTransportSocketCallbacks(Network::TransportSocket* parent) : parent_(parent) {} + + int fd() const override { return parent_->callbacks()->fd(); } + Network::Connection& connection() override { return parent_->callbacks()->connection(); } + bool shouldDrainReadBuffer() override { return false; } + /* + * No-op for these two methods to hold back the callbacks. + */ + void setReadBufferReady() override {} + void raiseEvent(Network::ConnectionEvent) override {} + +private: + Network::TransportSocket* parent_; +}; + +} // namespace TransportSockets +} // namespace Extensions +} // Envoy diff --git a/test/extensions/transport_sockets/alts/BUILD b/test/extensions/transport_sockets/alts/BUILD index 171a97fd28e43..d0ef9cb73ecc4 100644 --- a/test/extensions/transport_sockets/alts/BUILD +++ b/test/extensions/transport_sockets/alts/BUILD @@ -32,3 +32,13 @@ envoy_extension_cc_test( "//test/mocks/event:event_mocks", ], ) + +envoy_extension_cc_test( + name = "noop_transport_socket_callbacks_test", + srcs = ["noop_transport_socket_callbacks_test.cc"], + extension_name = "envoy.transport_sockets.alts", + deps = [ + "//source/extensions/transport_sockets/alts:noop_transport_socket_callbacks_lib", + "//test/mocks/network:network_mocks", + ], +) diff --git a/test/extensions/transport_sockets/alts/noop_transport_socket_callbacks_test.cc b/test/extensions/transport_sockets/alts/noop_transport_socket_callbacks_test.cc new file mode 100644 index 0000000000000..09ca0a2838e63 --- /dev/null +++ b/test/extensions/transport_sockets/alts/noop_transport_socket_callbacks_test.cc @@ -0,0 +1,60 @@ +#include "extensions/transport_sockets/alts/noop_transport_socket_callbacks.h" + +#include "test/mocks/network/mocks.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace TransportSockets { +namespace Alts { +namespace { + +class TestTransportSocketCallbacks : public Network::TransportSocketCallbacks { +public: + TestTransportSocketCallbacks(Network::Connection* connection) : connection_(connection) {} + + int fd() const override { return 1; } + Network::Connection& connection() override { return *connection_; } + bool shouldDrainReadBuffer() override { return false; } + void setReadBufferReady() override { set_read_buffer_ready_ = true; } + void raiseEvent(Network::ConnectionEvent) override { event_raised_ = true; } + + bool event_raised() const { return event_raised_; } + bool set_read_buffer_ready() const { return set_read_buffer_ready_; } + +private: + bool event_raised_{false}; + bool set_read_buffer_ready_{false}; + Network::Connection* connection_; +}; + +class NoOpTransportSocketCallbacksTest : public testing::Test { +protected: + NoOpTransportSocketCallbacksTest() + : wrapper_callbacks_(&connection_), wrapped_callbacks_(&wrapper_socket_) { + wrapper_socket_.setTransportSocketCallbacks(wrapper_callbacks_); + } + + Network::MockConnection connection_; + TestTransportSocketCallbacks wrapper_callbacks_; + Network::MockTransportSocket wrapper_socket_; + NoOpTransportSocketCallbacks wrapped_callbacks_; +}; + +TEST_F(NoOpTransportSocketCallbacksTest, TestAllCallbacks) { + EXPECT_EQ(wrapper_callbacks_.fd(), wrapped_callbacks_.fd()); + EXPECT_EQ(&connection_, &wrapped_callbacks_.connection()); + EXPECT_FALSE(wrapped_callbacks_.shouldDrainReadBuffer()); + + wrapped_callbacks_.setReadBufferReady(); + EXPECT_FALSE(wrapper_callbacks_.set_read_buffer_ready()); + wrapped_callbacks_.raiseEvent(Network::ConnectionEvent::Connected); + EXPECT_FALSE(wrapper_callbacks_.event_raised()); +} + +} // namespace +} // namespace Alts +} // namespace TransportSockets +} // namespace Extensions +} // namespace Envoy diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 651da87c4174b..df74655bee272 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -428,7 +428,7 @@ class MockTransportSocket : public TransportSocket { MockTransportSocket(); ~MockTransportSocket(); - MOCK_METHOD1(setTransportSocketCallbacks, void(TransportSocketCallbacks& callbacks)); + void setTransportSocketCallbacks(TransportSocketCallbacks& callbacks) { callbacks_ = &callbacks; } MOCK_CONST_METHOD0(protocol, std::string()); MOCK_METHOD0(canFlushClose, bool()); MOCK_METHOD1(closeSocket, void(Network::ConnectionEvent event)); @@ -437,6 +437,10 @@ class MockTransportSocket : public TransportSocket { MOCK_METHOD0(onConnected, void()); MOCK_METHOD0(ssl, Ssl::Connection*()); MOCK_CONST_METHOD0(ssl, const Ssl::Connection*()); + Network::TransportSocketCallbacks* callbacks() { return callbacks_; } + +private: + Network::TransportSocketCallbacks* callbacks_; }; class MockTransportSocketFactory : public TransportSocketFactory {