Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions include/envoy/network/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ class Connection : public Event::DeferredDeletable, public FilterManager {
*/
virtual void addConnectionCallbacks(ConnectionCallbacks& cb) PURE;

/**
* Unregister callbacks which previously fired when connection events occur.
*/
virtual void removeConnectionCallbacks(ConnectionCallbacks& cb) PURE;

/**
* Register for callback every time bytes are written to the underlying TransportSocket.
*/
Expand Down Expand Up @@ -241,6 +246,12 @@ class Connection : public Event::DeferredDeletable, public FilterManager {
*/
virtual State state() const PURE;

/**
* @return true if the connection has not completed connecting, false if the connection is
* established.
*/
virtual bool connecting() const PURE;

/**
* Write data to the connection. Will iterate through downstream filters with the buffer if any
* are installed.
Expand Down
5 changes: 5 additions & 0 deletions include/envoy/network/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -227,6 +227,11 @@ class FilterManager {
*/
virtual void addReadFilter(ReadFilterSharedPtr filter) PURE;

/**
* Remove a read filter from the connection.
*/
virtual void removeReadFilter(ReadFilterSharedPtr filter) PURE;
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What is the use case for this one? Naively I would assume that we can just make a raw connection without any read filters, detect ALPN, and then attach an HTTP filter, etc. and continue?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

the tcp pool by default adds a read filter, and disconnects if data is read when no session is associated.
We could skip that and the add-and-remove but I'd want something to make sure early data didn't go into the void, so I'm mildly inclined to leave as-is.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ah OK. I see. I guess it would be nice to avoid this interface change but up to you.


/**
* Initialize all of the installed read filters. This effectively calls onNewConnection() on
* each of them.
Expand Down
12 changes: 10 additions & 2 deletions source/common/network/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,10 @@ void ConnectionImpl::addReadFilter(ReadFilterSharedPtr filter) {
filter_manager_.addReadFilter(filter);
}

void ConnectionImpl::removeReadFilter(ReadFilterSharedPtr filter) {
filter_manager_.removeReadFilter(filter);
}

bool ConnectionImpl::initializeReadFilters() { return filter_manager_.initializeReadFilters(); }

void ConnectionImpl::close(ConnectionCloseType type) {
Expand Down Expand Up @@ -485,7 +489,9 @@ void ConnectionImpl::onWriteBufferLowWatermark() {
ASSERT(write_buffer_above_high_watermark_);
write_buffer_above_high_watermark_ = false;
for (ConnectionCallbacks* callback : callbacks_) {
callback->onBelowWriteBufferLowWatermark();
if (callback) {
callback->onBelowWriteBufferLowWatermark();
}
}
}

Expand All @@ -494,7 +500,9 @@ void ConnectionImpl::onWriteBufferHighWatermark() {
ASSERT(!write_buffer_above_high_watermark_);
write_buffer_above_high_watermark_ = true;
for (ConnectionCallbacks* callback : callbacks_) {
callback->onAboveWriteBufferHighWatermark();
if (callback) {
callback->onAboveWriteBufferHighWatermark();
}
}
}

Expand Down
2 changes: 2 additions & 0 deletions source/common/network/connection_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback
void addWriteFilter(WriteFilterSharedPtr filter) override;
void addFilter(FilterSharedPtr filter) override;
void addReadFilter(ReadFilterSharedPtr filter) override;
void removeReadFilter(ReadFilterSharedPtr filter) override;
bool initializeReadFilters() override;

// Network::Connection
Expand All @@ -78,6 +79,7 @@ class ConnectionImpl : public ConnectionImplBase, public TransportSocketCallback
absl::optional<UnixDomainSocketPeerCredentials> unixSocketPeerCredentials() const override;
Ssl::ConnectionInfoConstSharedPtr ssl() const override { return transport_socket_->ssl(); }
State state() const override;
bool connecting() const override { return connecting_; }
void write(Buffer::Instance& data, bool end_stream) override;
void setBufferLimits(uint32_t limit) override;
uint32_t bufferLimit() const override { return read_buffer_limit_; }
Expand Down
14 changes: 13 additions & 1 deletion source/common/network/connection_impl_base.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,16 @@ void ConnectionImplBase::addConnectionCallbacks(ConnectionCallbacks& cb) {
callbacks_.push_back(&cb);
}

void ConnectionImplBase::removeConnectionCallbacks(ConnectionCallbacks& callbacks) {
// For performance/safety reasons we just clear the callback and do not resize the list
for (auto& callback : callbacks_) {
if (callback == &callbacks) {
callback = nullptr;
return;
}
}
}

void ConnectionImplBase::hashKey(std::vector<uint8_t>& hash) const { addIdToHashKey(hash, id()); }

void ConnectionImplBase::setConnectionStats(const ConnectionStats& stats) {
Expand Down Expand Up @@ -45,7 +55,9 @@ void ConnectionImplBase::raiseConnectionEvent(ConnectionEvent event) {
for (ConnectionCallbacks* callback : callbacks_) {
// TODO(mattklein123): If we close while raising a connected event we should not raise further
// connected events.
callback->onEvent(event);
if (callback != nullptr) {
callback->onEvent(event);
}
}
}

Expand Down
1 change: 1 addition & 0 deletions source/common/network/connection_impl_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ class ConnectionImplBase : public FilterManagerConnection,

// Network::Connection
void addConnectionCallbacks(ConnectionCallbacks& cb) override;
void removeConnectionCallbacks(ConnectionCallbacks& cb) override;
Event::Dispatcher& dispatcher() override { return dispatcher_; }
uint64_t id() const override { return id_; }
void hashKey(std::vector<uint8_t>& hash) const override;
Expand Down
12 changes: 12 additions & 0 deletions source/common/network/filter_manager_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,15 @@ void FilterManagerImpl::addReadFilter(ReadFilterSharedPtr filter) {
LinkedList::moveIntoListBack(std::move(new_filter), upstream_filters_);
}

void FilterManagerImpl::removeReadFilter(ReadFilterSharedPtr filter_to_remove) {
// For perf/safety reasons, null this out rather than removing.
for (auto& filter : upstream_filters_) {
if (filter->filter_ == filter_to_remove) {
filter->filter_ = nullptr;
}
}
}

bool FilterManagerImpl::initializeReadFilters() {
if (upstream_filters_.empty()) {
return false;
Expand All @@ -53,6 +62,9 @@ void FilterManagerImpl::onContinueReading(ActiveReadFilter* filter,
}

for (; entry != upstream_filters_.end(); entry++) {
if (!(*entry)->filter_) {
continue;
}
if (!(*entry)->initialized_) {
(*entry)->initialized_ = true;
FilterStatus status = (*entry)->filter_->onNewConnection();
Expand Down
1 change: 1 addition & 0 deletions source/common/network/filter_manager_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,7 @@ class FilterManagerImpl {
void addWriteFilter(WriteFilterSharedPtr filter);
void addFilter(FilterSharedPtr filter);
void addReadFilter(ReadFilterSharedPtr filter);
void removeReadFilter(ReadFilterSharedPtr filter);
bool initializeReadFilters();
void onRead();
FilterStatus onWrite();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,10 @@ void QuicFilterManagerConnectionImpl::addReadFilter(Network::ReadFilterSharedPtr
filter_manager_.addReadFilter(filter);
}

void QuicFilterManagerConnectionImpl::removeReadFilter(Network::ReadFilterSharedPtr filter) {
filter_manager_.removeReadFilter(filter);
}

bool QuicFilterManagerConnectionImpl::initializeReadFilters() {
return filter_manager_.initializeReadFilters();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase {
void addWriteFilter(Network::WriteFilterSharedPtr filter) override;
void addFilter(Network::FilterSharedPtr filter) override;
void addReadFilter(Network::ReadFilterSharedPtr filter) override;
void removeReadFilter(Network::ReadFilterSharedPtr filter) override;
bool initializeReadFilters() override;

// Network::Connection
Expand Down Expand Up @@ -63,6 +64,12 @@ class QuicFilterManagerConnectionImpl : public Network::ConnectionImplBase {
}
return Network::Connection::State::Closed;
}
bool connecting() const override {
if (quic_connection_ != nullptr && quic_connection_->connected()) {
return false;
}
return true;
}
void write(Buffer::Instance& /*data*/, bool /*end_stream*/) override {
// All writes should be handled by Quic internally.
NOT_REACHED_GCOVR_EXCL_LINE;
Expand Down
7 changes: 7 additions & 0 deletions source/server/api_listener_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -83,12 +83,18 @@ class ApiListenerImplBase : public ApiListener,
}
void addFilter(Network::FilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
void addReadFilter(Network::ReadFilterSharedPtr) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
void removeReadFilter(Network::ReadFilterSharedPtr) override {
NOT_IMPLEMENTED_GCOVR_EXCL_LINE;
}
bool initializeReadFilters() override { return true; }

// Network::Connection
void addConnectionCallbacks(Network::ConnectionCallbacks& cb) override {
callbacks_.push_back(&cb);
}
void removeConnectionCallbacks(Network::ConnectionCallbacks& cb) override {
callbacks_.remove(&cb);
}
void addBytesSentCallback(Network::Connection::BytesSentCb) override {
NOT_IMPLEMENTED_GCOVR_EXCL_LINE;
}
Expand Down Expand Up @@ -121,6 +127,7 @@ class ApiListenerImplBase : public ApiListener,
Ssl::ConnectionInfoConstSharedPtr ssl() const override { return nullptr; }
absl::string_view requestedServerName() const override { return EMPTY_STRING; }
State state() const override { return Network::Connection::State::Open; }
bool connecting() const override { return false; }
void write(Buffer::Instance&, bool) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
void setBufferLimits(uint32_t) override { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; }
uint32_t bufferLimit() const override { return 65000; }
Expand Down
25 changes: 25 additions & 0 deletions test/common/network/connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -287,6 +287,15 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) {
Buffer::OwnedImpl buffer("hello world");
client_connection_->write(buffer, false);
client_connection_->connect();
EXPECT_TRUE(client_connection_->connecting());

StrictMock<MockConnectionCallbacks> added_and_removed_callbacks;
// Make sure removed connections don't get events.
client_connection_->addConnectionCallbacks(added_and_removed_callbacks);
client_connection_->removeConnectionCallbacks(added_and_removed_callbacks);

std::shared_ptr<MockReadFilter> add_and_remove_filter =
std::make_shared<StrictMock<MockReadFilter>>();

EXPECT_CALL(client_callbacks_, onEvent(ConnectionEvent::Connected))
.WillOnce(Invoke([&](Network::ConnectionEvent) -> void {
Expand All @@ -302,6 +311,8 @@ TEST_P(ConnectionImplTest, CloseDuringConnectCallback) {
std::move(socket), Network::Test::createRawBufferSocket(), stream_info_);
server_connection_->addConnectionCallbacks(server_callbacks_);
server_connection_->addReadFilter(read_filter_);
server_connection_->addReadFilter(add_and_remove_filter);
server_connection_->removeReadFilter(add_and_remove_filter);
}));

EXPECT_CALL(server_callbacks_, onEvent(ConnectionEvent::RemoteClose))
Expand Down Expand Up @@ -537,13 +548,22 @@ TEST_P(ConnectionImplTest, ConnectionStats) {

MockConnectionStats client_connection_stats;
client_connection_->setConnectionStats(client_connection_stats.toBufferStats());
EXPECT_TRUE(client_connection_->connecting());
client_connection_->connect();
// The Network::Connection class oddly uses onWrite as its indicator of if
// it's done connection, rather than the Connected event.
EXPECT_TRUE(client_connection_->connecting());

std::shared_ptr<MockWriteFilter> write_filter(new MockWriteFilter());
std::shared_ptr<MockFilter> filter(new MockFilter());
client_connection_->addFilter(filter);
client_connection_->addWriteFilter(write_filter);

// Make sure removed filters don't get callbacks.
std::shared_ptr<MockReadFilter> read_filter(new StrictMock<MockReadFilter>());
client_connection_->addReadFilter(read_filter);
client_connection_->removeReadFilter(read_filter);

Sequence s1;
EXPECT_CALL(*write_filter, onWrite(_, _))
.InSequence(s1)
Expand Down Expand Up @@ -854,6 +874,11 @@ TEST_P(ConnectionImplTest, WriteWatermarks) {
setUpBasicConnection();
EXPECT_FALSE(client_connection_->aboveHighWatermark());

StrictMock<MockConnectionCallbacks> added_and_removed_callbacks;
// Make sure removed connections don't get events.
client_connection_->addConnectionCallbacks(added_and_removed_callbacks);
client_connection_->removeConnectionCallbacks(added_and_removed_callbacks);

// Stick 5 bytes in the connection buffer.
std::unique_ptr<Buffer::OwnedImpl> buffer(new Buffer::OwnedImpl("hello"));
int buffer_len = buffer->length();
Expand Down
6 changes: 6 additions & 0 deletions test/mocks/network/connection.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,12 @@ class MockConnectionBase {
#define DEFINE_MOCK_CONNECTION_MOCK_METHODS \
/* Network::Connection */ \
MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb)); \
MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb)); \
MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb)); \
MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter)); \
MOCK_METHOD(void, addFilter, (FilterSharedPtr filter)); \
MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter)); \
MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter)); \
MOCK_METHOD(void, enableHalfClose, (bool enabled)); \
MOCK_METHOD(void, close, (ConnectionCloseType type)); \
MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); \
Expand All @@ -72,6 +74,7 @@ class MockConnectionBase {
MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const)); \
MOCK_METHOD(absl::string_view, requestedServerName, (), (const)); \
MOCK_METHOD(State, state, (), (const)); \
MOCK_METHOD(bool, connecting, (), (const)); \
MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream)); \
MOCK_METHOD(void, setBufferLimits, (uint32_t limit)); \
MOCK_METHOD(uint32_t, bufferLimit, (), (const)); \
Expand Down Expand Up @@ -128,10 +131,12 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC

// Network::Connection
MOCK_METHOD(void, addConnectionCallbacks, (ConnectionCallbacks & cb));
MOCK_METHOD(void, removeConnectionCallbacks, (ConnectionCallbacks & cb));
MOCK_METHOD(void, addBytesSentCallback, (BytesSentCb cb));
MOCK_METHOD(void, addWriteFilter, (WriteFilterSharedPtr filter));
MOCK_METHOD(void, addFilter, (FilterSharedPtr filter));
MOCK_METHOD(void, addReadFilter, (ReadFilterSharedPtr filter));
MOCK_METHOD(void, removeReadFilter, (ReadFilterSharedPtr filter));
MOCK_METHOD(void, enableHalfClose, (bool enabled));
MOCK_METHOD(void, close, (ConnectionCloseType type));
MOCK_METHOD(Event::Dispatcher&, dispatcher, ());
Expand All @@ -152,6 +157,7 @@ class MockFilterManagerConnection : public FilterManagerConnection, public MockC
MOCK_METHOD(Ssl::ConnectionInfoConstSharedPtr, ssl, (), (const));
MOCK_METHOD(absl::string_view, requestedServerName, (), (const));
MOCK_METHOD(State, state, (), (const));
MOCK_METHOD(bool, connecting, (), (const));
MOCK_METHOD(void, write, (Buffer::Instance & data, bool end_stream));
MOCK_METHOD(void, setBufferLimits, (uint32_t limit));
MOCK_METHOD(uint32_t, bufferLimit, (), (const));
Expand Down