From d63247bbeb2026ff1f994b54611225f164996711 Mon Sep 17 00:00:00 2001 From: Alyssa Wilk Date: Tue, 27 Oct 2020 16:11:04 -0400 Subject: [PATCH] network: adding some accessors for ALPN work. Signed-off-by: Alyssa Wilk --- include/envoy/network/connection.h | 11 ++++++++ include/envoy/network/filter.h | 5 ++++ source/common/network/connection_impl.cc | 12 +++++++-- source/common/network/connection_impl.h | 2 ++ source/common/network/connection_impl_base.cc | 14 ++++++++++- source/common/network/connection_impl_base.h | 1 + source/common/network/filter_manager_impl.cc | 12 +++++++++ source/common/network/filter_manager_impl.h | 1 + .../quic_filter_manager_connection_impl.cc | 4 +++ .../quic_filter_manager_connection_impl.h | 7 ++++++ source/server/api_listener_impl.h | 7 ++++++ test/common/network/connection_impl_test.cc | 25 +++++++++++++++++++ test/mocks/network/connection.h | 9 +++++++ 13 files changed, 107 insertions(+), 3 deletions(-) diff --git a/include/envoy/network/connection.h b/include/envoy/network/connection.h index b486f614ed964..7606d30a725cd 100644 --- a/include/envoy/network/connection.h +++ b/include/envoy/network/connection.h @@ -101,6 +101,11 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual void addConnectionCallbacks(ConnectionCallbacks& cb) PURE; + /** + * Unregister callbacks which previously fired when connection events occur. + */ + virtual void removeConnectionCallbacks(ConnectionCallbacks& cb) PURE; + /** * Register for callback every time bytes are written to the underlying TransportSocket. */ @@ -241,6 +246,12 @@ class Connection : public Event::DeferredDeletable, public FilterManager { */ virtual State state() const PURE; + /** + * @return true if the connection has not completed connecting, false if the connection is + * established. + */ + virtual bool connecting() const PURE; + /** * Write data to the connection. Will iterate through downstream filters with the buffer if any * are installed. diff --git a/include/envoy/network/filter.h b/include/envoy/network/filter.h index a111b1a22ed47..f34aba5c1acae 100644 --- a/include/envoy/network/filter.h +++ b/include/envoy/network/filter.h @@ -227,6 +227,11 @@ class FilterManager { */ virtual void addReadFilter(ReadFilterSharedPtr filter) PURE; + /** + * Remove a read filter from the connection. + */ + virtual void removeReadFilter(ReadFilterSharedPtr filter) PURE; + /** * Initialize all of the installed read filters. This effectively calls onNewConnection() on * each of them. diff --git a/source/common/network/connection_impl.cc b/source/common/network/connection_impl.cc index 2804bad755ba1..0e79089b7d374 100644 --- a/source/common/network/connection_impl.cc +++ b/source/common/network/connection_impl.cc @@ -95,6 +95,10 @@ void ConnectionImpl::addReadFilter(ReadFilterSharedPtr filter) { filter_manager_.addReadFilter(filter); } +void ConnectionImpl::removeReadFilter(ReadFilterSharedPtr filter) { + filter_manager_.removeReadFilter(filter); +} + bool ConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); } void ConnectionImpl::close(ConnectionCloseType type) { @@ -484,7 +488,9 @@ void ConnectionImpl::onWriteBufferLowWatermark() { ASSERT(write_buffer_above_high_watermark_); write_buffer_above_high_watermark_ = false; for (ConnectionCallbacks* callback : callbacks_) { - callback->onBelowWriteBufferLowWatermark(); + if (callback) { + callback->onBelowWriteBufferLowWatermark(); + } } } @@ -493,7 +499,9 @@ void ConnectionImpl::onWriteBufferHighWatermark() { ASSERT(!write_buffer_above_high_watermark_); write_buffer_above_high_watermark_ = true; for (ConnectionCallbacks* callback : callbacks_) { - callback->onAboveWriteBufferHighWatermark(); + if (callback) { + callback->onAboveWriteBufferHighWatermark(); + } } } diff --git a/source/common/network/connection_impl.h b/source/common/network/connection_impl.h index e28e05e9d1822..604c0077be9fb 100644 --- a/source/common/network/connection_impl.h +++ b/source/common/network/connection_impl.h @@ -55,6 +55,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback void addWriteFilter(WriteFilterSharedPtr filter) override; void addFilter(FilterSharedPtr filter) override; void addReadFilter(ReadFilterSharedPtr filter) override; + void removeReadFilter(ReadFilterSharedPtr filter) override; bool initializeReadFilters() override; // Network::Connection @@ -78,6 +79,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback absl::optional unixSocketPeerCredentials() const override; Ssl::ConnectionInfoConstSharedPtr ssl() const override { return transport_socket_->ssl(); } State state() const override; + bool connecting() const override { return connecting_; } void write(Buffer::Instance& data, bool end_stream) override; void setBufferLimits(uint32_t limit) override; uint32_t bufferLimit() const override { return read_buffer_limit_; } diff --git a/source/common/network/connection_impl_base.cc b/source/common/network/connection_impl_base.cc index e048465a4b35a..775b09be13e40 100644 --- a/source/common/network/connection_impl_base.cc +++ b/source/common/network/connection_impl_base.cc @@ -18,6 +18,16 @@ void ConnectionImplBase::addConnectionCallbacks(ConnectionCallbacks& cb) { callbacks_.push_back(&cb); } +void ConnectionImplBase::removeConnectionCallbacks(ConnectionCallbacks& callbacks) { + // For performance/safety reasons we just clear the callback and do not resize the list + for (auto& callback : callbacks_) { + if (callback == &callbacks) { + callback = nullptr; + return; + } + } +} + void ConnectionImplBase::hashKey(std::vector& hash) const { addIdToHashKey(hash, id()); } void ConnectionImplBase::setConnectionStats(const ConnectionStats& stats) { @@ -45,7 +55,9 @@ void ConnectionImplBase::raiseConnectionEvent(ConnectionEvent event) { for (ConnectionCallbacks* callback : callbacks_) { // TODO(mattklein123): If we close while raising a connected event we should not raise further // connected events. - callback->onEvent(event); + if (callback != nullptr) { + callback->onEvent(event); + } } } diff --git a/source/common/network/connection_impl_base.h b/source/common/network/connection_impl_base.h index d0bf93670cffb..5bb12eea5a7d9 100644 --- a/source/common/network/connection_impl_base.h +++ b/source/common/network/connection_impl_base.h @@ -22,6 +22,7 @@ class ConnectionImplBase : public FilterManagerConnection, // Network::Connection void addConnectionCallbacks(ConnectionCallbacks& cb) override; + void removeConnectionCallbacks(ConnectionCallbacks& cb) override; Event::Dispatcher& dispatcher() override { return dispatcher_; } uint64_t id() const override { return id_; } void hashKey(std::vector& hash) const override; diff --git a/source/common/network/filter_manager_impl.cc b/source/common/network/filter_manager_impl.cc index 593abc0980951..5bec16f5f75ae 100644 --- a/source/common/network/filter_manager_impl.cc +++ b/source/common/network/filter_manager_impl.cc @@ -28,6 +28,15 @@ void FilterManagerImpl::addReadFilter(ReadFilterSharedPtr filter) { LinkedList::moveIntoListBack(std::move(new_filter), upstream_filters_); } +void FilterManagerImpl::removeReadFilter(ReadFilterSharedPtr filter_to_remove) { + // For perf/safety reasons, null this out rather than removing. + for (auto& filter : upstream_filters_) { + if (filter->filter_ == filter_to_remove) { + filter->filter_ = nullptr; + } + } +} + bool FilterManagerImpl::initializeReadFilters() { if (upstream_filters_.empty()) { return false; @@ -53,6 +62,9 @@ void FilterManagerImpl::onContinueReading(ActiveReadFilter* filter, } for (; entry != upstream_filters_.end(); entry++) { + if (!(*entry)->filter_) { + continue; + } if (!(*entry)->initialized_) { (*entry)->initialized_ = true; FilterStatus status = (*entry)->filter_->onNewConnection(); diff --git a/source/common/network/filter_manager_impl.h b/source/common/network/filter_manager_impl.h index 0975c2ecd7ede..a74ba02c56c58 100644 --- a/source/common/network/filter_manager_impl.h +++ b/source/common/network/filter_manager_impl.h @@ -105,6 +105,7 @@ class FilterManagerImpl { void addWriteFilter(WriteFilterSharedPtr filter); void addFilter(FilterSharedPtr filter); void addReadFilter(ReadFilterSharedPtr filter); + void removeReadFilter(ReadFilterSharedPtr filter); bool initializeReadFilters(); void onRead(); FilterStatus onWrite(); diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc index e005a3dd7691c..3e30e6ec5779f 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.cc @@ -30,6 +30,10 @@ void QuicFilterManagerConnectionImpl::addReadFilter(Network::ReadFilterSharedPtr filter_manager_.addReadFilter(filter); } +void QuicFilterManagerConnectionImpl::removeReadFilter(Network::ReadFilterSharedPtr filter) { + filter_manager_.removeReadFilter(filter); +} + bool QuicFilterManagerConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); } diff --git a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h index cf049ab5ac523..8f01d03ca6b9c 100644 --- a/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h +++ b/source/extensions/quic_listeners/quiche/quic_filter_manager_connection_impl.h @@ -25,6 +25,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase { void addWriteFilter(Network::WriteFilterSharedPtr filter) override; void addFilter(Network::FilterSharedPtr filter) override; void addReadFilter(Network::ReadFilterSharedPtr filter) override; + void removeReadFilter(Network::ReadFilterSharedPtr filter) override; bool initializeReadFilters() override; // Network::Connection @@ -63,6 +64,12 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase { } return Network::Connection::State::Closed; } + bool connecting() const override { + if (quic_connection_ != nullptr && quic_connection_->connected()) { + return false; + } + return true; + } void write(Buffer::Instance& /*data*/, bool /*end_stream*/) override { // All writes should be handled by Quic internally. NOT_REACHED_GCOVR_EXCL_LINE; diff --git a/source/server/api_listener_impl.h b/source/server/api_listener_impl.h index 4731ea90ca54a..b0dd0ef701c34 100644 --- a/source/server/api_listener_impl.h +++ b/source/server/api_listener_impl.h @@ -83,12 +83,18 @@ class ApiListenerImplBase : public ApiListener, } void addFilter(Network::FilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void addReadFilter(Network::ReadFilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + void removeReadFilter(Network::ReadFilterSharedPtr) override { + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } bool initializeReadFilters() override { return true; } // Network::Connection void addConnectionCallbacks(Network::ConnectionCallbacks& cb) override { callbacks_.push_back(&cb); } + void removeConnectionCallbacks(Network::ConnectionCallbacks& cb) override { + callbacks_.remove(&cb); + } void addBytesSentCallback(Network::Connection::BytesSentCb) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } @@ -121,6 +127,7 @@ class ApiListenerImplBase : public ApiListener, Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; } absl::string_view requestedServerName() const override { return EMPTY_STRING; } State state() const override { return Network::Connection::State::Open; } + bool connecting() const override { return false; } void write(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } void setBufferLimits(uint32_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } uint32_t bufferLimit() const override { return 65000; } diff --git a/test/common/network/connection_impl_test.cc b/test/common/network/connection_impl_test.cc index 086551ade5927..2d2f0a646942e 100644 --- a/test/common/network/connection_impl_test.cc +++ b/test/common/network/connection_impl_test.cc @@ -306,6 +306,15 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) { Buffer::OwnedImpl buffer("hello world"); client_connection_->write(buffer, false); client_connection_->connect(); + EXPECT_TRUE(client_connection_->connecting()); + + StrictMock added_and_removed_callbacks; + // Make sure removed connections don't get events. + client_connection_->addConnectionCallbacks(added_and_removed_callbacks); + client_connection_->removeConnectionCallbacks(added_and_removed_callbacks); + + std::shared_ptr add_and_remove_filter = + std::make_shared>(); EXPECT_CALL(client_callbacks_, onEvent(ConnectionEvent::Connected)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { @@ -321,6 +330,8 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) { std::move(socket), Network::Test::createRawBufferSocket(), stream_info_); server_connection_->addConnectionCallbacks(server_callbacks_); server_connection_->addReadFilter(read_filter_); + server_connection_->addReadFilter(add_and_remove_filter); + server_connection_->removeReadFilter(add_and_remove_filter); })); EXPECT_CALL(server_callbacks_, onEvent(ConnectionEvent::RemoteClose)) @@ -497,13 +508,22 @@ TEST_P(ConnectionImplTest, ConnectionStats) { MockConnectionStats client_connection_stats; client_connection_->setConnectionStats(client_connection_stats.toBufferStats()); + EXPECT_TRUE(client_connection_->connecting()); client_connection_->connect(); + // The Network::Connection class oddly uses onWrite as its indicator of if + // it's done connection, rather than the Connected event. + EXPECT_TRUE(client_connection_->connecting()); std::shared_ptr write_filter(new MockWriteFilter()); std::shared_ptr filter(new MockFilter()); client_connection_->addFilter(filter); client_connection_->addWriteFilter(write_filter); + // Make sure removed filters don't get callbacks. + std::shared_ptr read_filter(new StrictMock()); + client_connection_->addReadFilter(read_filter); + client_connection_->removeReadFilter(read_filter); + Sequence s1; EXPECT_CALL(*write_filter, onWrite(_, _)) .InSequence(s1) @@ -814,6 +834,11 @@ TEST_P(ConnectionImplTest, WriteWatermarks) { setUpBasicConnection(); EXPECT_FALSE(client_connection_->aboveHighWatermark()); + StrictMock added_and_removed_callbacks; + // Make sure removed connections don't get events. + client_connection_->addConnectionCallbacks(added_and_removed_callbacks); + client_connection_->removeConnectionCallbacks(added_and_removed_callbacks); + // Stick 5 bytes in the connection buffer. std::unique_ptr buffer(new Buffer::OwnedImpl("hello")); int buffer_len = buffer->length(); diff --git a/test/mocks/network/connection.h b/test/mocks/network/connection.h index 6a7856887fe03..0c9d9ee654e00 100644 --- a/test/mocks/network/connection.h +++ b/test/mocks/network/connection.h @@ -52,10 +52,12 @@ class MockConnection : public Connection, public MockConnectionBase { // Network::Connection MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); + MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); + MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); MOCK_METHOD(void, enableHalfClose, (bool enabled)); MOCK_METHOD(void, close, (ConnectionCloseType type)); MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); @@ -76,6 +78,7 @@ class MockConnection : public Connection, public MockConnectionBase { MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); + MOCK_METHOD(bool, connecting, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); MOCK_METHOD(uint32_t, bufferLimit, (), (const)); @@ -100,10 +103,12 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase // Network::Connection MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); + MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); + MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); MOCK_METHOD(void, enableHalfClose, (bool enabled)); MOCK_METHOD(void, close, (ConnectionCloseType type)); MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); @@ -124,6 +129,7 @@ class MockClientConnection : public ClientConnection, public MockConnectionBase MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); + MOCK_METHOD(bool, connecting, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); MOCK_METHOD(uint32_t, bufferLimit, (), (const)); @@ -151,10 +157,12 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC // Network::Connection MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); + MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); + MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); MOCK_METHOD(void, enableHalfClose, (bool enabled)); MOCK_METHOD(void, close, (ConnectionCloseType type)); MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); @@ -175,6 +183,7 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); MOCK_METHOD(State, state, (), (const)); + MOCK_METHOD(bool, connecting, (), (const)); MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); MOCK_METHOD(uint32_t, bufferLimit, (), (const));