Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions source/extensions/transport_sockets/alts/tsi_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result
frame_protector_ = std::make_unique<TsiFrameProtector>(frame_protector);

handshake_complete_ = true;
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
if (raw_write_buffer_.length() == 0) {
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
}

if (read_error_ || (!handshake_complete_ && end_stream_read_)) {
Expand All @@ -167,8 +169,8 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result
// Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
if (raw_write_buffer_.length() > 0) {
Network::IoResult result = raw_buffer_socket_->doWrite(raw_write_buffer_, false);
if (handshake_complete_ && raw_write_buffer_.length() > 0) {
write_buffer_contains_handshake_bytes_ = true;
if (handshake_complete_ && result.action_ != Network::PostIoAction::Close) {
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
return result.action_;
}
Expand Down Expand Up @@ -266,8 +268,8 @@ Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
Network::IoResult TsiSocket::repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream) {
uint64_t total_bytes_written = 0;
Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};

ASSERT(!write_buffer_contains_handshake_bytes_);
// There should be no handshake bytes in raw_write_buffer_.
ASSERT(!(raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0));
while (true) {
uint64_t bytes_to_drain_this_iteration =

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://github.com/envoyproxy/envoy/pull/16514/files

Could you merge upstream/main and consider changing the "std::min<size_t>(" to "std::min<uint64_t>("

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing. Done.

prev_bytes_to_drain_ > 0
Expand Down Expand Up @@ -327,8 +329,7 @@ Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream)
} else {
ASSERT(frame_protector_);
// Check if we need to flush outstanding handshake bytes.
if (write_buffer_contains_handshake_bytes_) {
ASSERT(raw_write_buffer_.length() > 0);
if (raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0) {

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to add tests to cover the case where doHandshakeNextDone adds bytes and how this version behaves better than the case where we only set "write_buffer_contains_handshake_bytes_ = true" in one special case.

ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
Network::IoResult result =
Expand All @@ -337,7 +338,6 @@ Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream)
if (raw_write_buffer_.length() > 0) {
return {result.action_, 0, false};
}
write_buffer_contains_handshake_bytes_ = false;
}
return repeatProtectAndWrite(buffer, end_stream);
}
Expand Down
1 change: 0 additions & 1 deletion source/extensions/transport_sockets/alts/tsi_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,6 @@ class TsiSocket : public Network::TransportSocket,
bool handshake_complete_{};
bool end_stream_read_{};
bool read_error_{};
bool write_buffer_contains_handshake_bytes_{};
uint64_t prev_bytes_to_drain_{};
};

Expand Down
8 changes: 4 additions & 4 deletions test/extensions/transport_sockets/alts/tsi_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ class TsiSocketTest : public testing::Test {
EXPECT_EQ(0L, client_.read_buffer_.length());

EXPECT_CALL(*server_.raw_socket_, doRead(_));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doRead(_));
expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false},
server_.tsi_socket_->doRead(server_.read_buffer_));
Expand Down Expand Up @@ -342,8 +342,8 @@ TEST_F(TsiSocketTest, HandshakeWithUnusedData) {
client_to_server_.add(makeFakeTsiFrame(ClientToServerData));

EXPECT_CALL(*server_.raw_socket_, doRead(_));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doRead(_));
expectIoResult({Network::PostIoAction::KeepOpen, 17UL, false},
server_.tsi_socket_->doRead(server_.read_buffer_));
Expand Down Expand Up @@ -378,8 +378,8 @@ TEST_F(TsiSocketTest, HandshakeWithUnusedDataAndEndOfStream) {
buffer.move(client_to_server_);
return result;
}));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
expectIoResult({Network::PostIoAction::KeepOpen, 17UL, true},
server_.tsi_socket_->doRead(server_.read_buffer_));
EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString());
Expand Down Expand Up @@ -841,7 +841,6 @@ TEST_F(TsiSocketTest, DoWriteOutstandingHandshakeData) {
EXPECT_EQ(0L, client_.read_buffer_.length());

EXPECT_CALL(*server_.raw_socket_, doRead(_));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));

// Write the first part of handshake data (14 bytes).
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false))
Expand All @@ -850,6 +849,7 @@ TEST_F(TsiSocketTest, DoWriteOutstandingHandshakeData) {
server_to_client_.move(buffer, 14);
return result;
}));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doRead(_));
expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false},
server_.tsi_socket_->doRead(server_.read_buffer_));
Expand Down