Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 9 additions & 9 deletions source/extensions/transport_sockets/alts/tsi_socket.cc
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,9 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result
frame_protector_ = std::make_unique<TsiFrameProtector>(frame_protector);

handshake_complete_ = true;
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
if (raw_write_buffer_.length() == 0) {
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
}

if (read_error_ || (!handshake_complete_ && end_stream_read_)) {
Expand All @@ -167,8 +169,8 @@ Network::PostIoAction TsiSocket::doHandshakeNextDone(NextResultPtr&& next_result
// Try to write raw buffer when next call is done, even this is not in do[Read|Write] stack.
if (raw_write_buffer_.length() > 0) {
Network::IoResult result = raw_buffer_socket_->doWrite(raw_write_buffer_, false);
if (handshake_complete_ && raw_write_buffer_.length() > 0) {
write_buffer_contains_handshake_bytes_ = true;
if (handshake_complete_ && result.action_ != Network::PostIoAction::Close) {
callbacks_->raiseEvent(Network::ConnectionEvent::Connected);
}
return result.action_;
}
Expand Down Expand Up @@ -266,13 +268,13 @@ Network::IoResult TsiSocket::doRead(Buffer::Instance& buffer) {
Network::IoResult TsiSocket::repeatProtectAndWrite(Buffer::Instance& buffer, bool end_stream) {
uint64_t total_bytes_written = 0;
Network::IoResult result = {Network::PostIoAction::KeepOpen, 0, false};

ASSERT(!write_buffer_contains_handshake_bytes_);
// There should be no handshake bytes in raw_write_buffer_.
ASSERT(!(raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0));
while (true) {
uint64_t bytes_to_drain_this_iteration =
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See https://github.com/envoyproxy/envoy/pull/16514/files

Could you merge upstream/main and consider changing the "std::min<size_t>(" to "std::min<uint64_t>("

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure thing. Done.

prev_bytes_to_drain_ > 0
? prev_bytes_to_drain_
: std::min<size_t>(buffer.length(), actual_frame_size_to_use_ - frame_overhead_size_);
: std::min<uint64_t>(buffer.length(), actual_frame_size_to_use_ - frame_overhead_size_);
// Consumed all data. Exit.
if (bytes_to_drain_this_iteration == 0) {
break;
Expand Down Expand Up @@ -327,8 +329,7 @@ Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream)
} else {
ASSERT(frame_protector_);
// Check if we need to flush outstanding handshake bytes.
if (write_buffer_contains_handshake_bytes_) {
ASSERT(raw_write_buffer_.length() > 0);
if (raw_write_buffer_.length() > 0 && prev_bytes_to_drain_ == 0) {
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be helpful to add tests to cover the case where doHandshakeNextDone adds bytes and how this version behaves better than the case where we only set "write_buffer_contains_handshake_bytes_ = true" in one special case.

ENVOY_CONN_LOG(debug, "TSI: raw_write length {} end_stream {}", callbacks_->connection(),
raw_write_buffer_.length(), end_stream);
Network::IoResult result =
Expand All @@ -337,7 +338,6 @@ Network::IoResult TsiSocket::doWrite(Buffer::Instance& buffer, bool end_stream)
if (raw_write_buffer_.length() > 0) {
return {result.action_, 0, false};
}
write_buffer_contains_handshake_bytes_ = false;
}
return repeatProtectAndWrite(buffer, end_stream);
}
Expand Down
1 change: 0 additions & 1 deletion source/extensions/transport_sockets/alts/tsi_socket.h
Original file line number Diff line number Diff line change
Expand Up @@ -117,7 +117,6 @@ class TsiSocket : public Network::TransportSocket,
bool handshake_complete_{};
bool end_stream_read_{};
bool read_error_{};
bool write_buffer_contains_handshake_bytes_{};
uint64_t prev_bytes_to_drain_{};
};

Expand Down
47 changes: 35 additions & 12 deletions test/extensions/transport_sockets/alts/alts_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,12 +60,14 @@ class CapturingHandshakerService : public grpc::gcp::HandshakerService::Service
while (stream->Read(&request)) {
if (request.has_client_start()) {
client_versions = request.client_start().rpc_versions();
client_max_frame_size = request.client_start().max_frame_size();
// Sets response to make first request successful.
response.set_out_frames(kClientInitFrame);
response.set_bytes_consumed(0);
response.mutable_status()->set_code(grpc::StatusCode::OK);
} else if (request.has_server_start()) {
server_versions = request.server_start().rpc_versions();
server_max_frame_size = request.server_start().max_frame_size();
response.mutable_status()->set_code(grpc::StatusCode::CANCELLED);
}
stream->Write(response);
Expand All @@ -81,9 +83,8 @@ class CapturingHandshakerService : public grpc::gcp::HandshakerService::Service
grpc::gcp::RpcProtocolVersions client_versions;
grpc::gcp::RpcProtocolVersions server_versions;

// TODO(yihuazhang): Test maximum frame size stored in handshake messages
// after updating test/core/tsi/alts/fake_handshaker/handshaker.proto to
// support maximum frame size negotiation.
size_t client_max_frame_size{0};
size_t server_max_frame_size{0};
};

class AltsIntegrationTestBase : public Event::TestUsingSimulatedTime,
Expand Down Expand Up @@ -178,19 +179,19 @@ class AltsIntegrationTestBase : public Event::TestUsingSimulatedTime,
fake_handshaker_server_thread_->join();
}

Network::ClientConnectionPtr makeAltsConnection() {
Network::Address::InstanceConstSharedPtr address = getAddress(version_, lookupPort("http"));
Network::TransportSocketPtr makeAltsTransportSocket() {
auto client_transport_socket = client_alts_->createTransportSocket(nullptr);
client_tsi_socket_ = dynamic_cast<TsiSocket*>(client_transport_socket.get());
client_tsi_socket_->setActualFrameSizeToUse(16384);
client_tsi_socket_->setFrameOverheadSize(4);
return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(),
std::move(client_transport_socket), nullptr);
return client_transport_socket;
}

void verifyActualFrameSizeToUse() {
EXPECT_NE(client_tsi_socket_, nullptr);
EXPECT_EQ(client_tsi_socket_->actualFrameSizeToUse(), 16384);
Network::ClientConnectionPtr makeAltsConnection() {
auto client_transport_socket = makeAltsTransportSocket();
Network::Address::InstanceConstSharedPtr address = getAddress(version_, lookupPort("http"));
return dispatcher_->createClientConnection(address, Network::Address::InstanceConstSharedPtr(),
std::move(client_transport_socket), nullptr);
}

std::string fakeHandshakerServerAddress(bool connect_to_handshaker) {
Expand Down Expand Up @@ -245,7 +246,21 @@ TEST_P(AltsIntegrationTestValidPeer, RouterRequestAndResponseWithBodyNoBuffer) {
return makeAltsConnection();
};
testRouterRequestAndResponseWithBody(1024, 512, false, false, &creator);
verifyActualFrameSizeToUse();
}

TEST_P(AltsIntegrationTestValidPeer, RouterRequestAndResponseWithBodyRawHttp) {
autonomous_upstream_ = true;
initialize();
std::string response;
sendRawHttpAndWaitForResponse(lookupPort("http"),
"GET / HTTP/1.1\r\n"
"Host: foo.com\r\n"
"Foo: bar\r\n"
"User-Agent: public\r\n"
"User-Agent: 123\r\n"
"Eep: baz\r\n\r\n",
&response, true, makeAltsTransportSocket());
EXPECT_THAT(response, testing::StartsWith("HTTP/1.1 200 OK\r\n"));
}

class AltsIntegrationTestEmptyPeer : public AltsIntegrationTestBase {
Expand All @@ -267,7 +282,6 @@ TEST_P(AltsIntegrationTestEmptyPeer, RouterRequestAndResponseWithBodyNoBuffer) {
return makeAltsConnection();
};
testRouterRequestAndResponseWithBody(1024, 512, false, false, &creator);
verifyActualFrameSizeToUse();
}

class AltsIntegrationTestClientInvalidPeer : public AltsIntegrationTestBase {
Expand Down Expand Up @@ -370,6 +384,15 @@ TEST_P(AltsIntegrationTestCapturingHandshaker, CheckAltsVersion) {
EXPECT_NE(0, capturing_handshaker_service_->client_versions.min_rpc_version().minor());
}

// Verifies that handshake request should include max frame size.
TEST_P(AltsIntegrationTestCapturingHandshaker, CheckMaxFrameSize) {
initialize();
codec_client_ = makeRawHttpConnection(makeAltsConnection(), absl::nullopt);
EXPECT_FALSE(codec_client_->connected());
EXPECT_EQ(capturing_handshaker_service_->client_max_frame_size, 16384);
EXPECT_EQ(capturing_handshaker_service_->server_max_frame_size, 16384);
}

} // namespace
} // namespace Alts
} // namespace TransportSockets
Expand Down
8 changes: 4 additions & 4 deletions test/extensions/transport_sockets/alts/tsi_socket_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -171,8 +171,8 @@ class TsiSocketTest : public testing::Test {
EXPECT_EQ(0L, client_.read_buffer_.length());

EXPECT_CALL(*server_.raw_socket_, doRead(_));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doRead(_));
expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false},
server_.tsi_socket_->doRead(server_.read_buffer_));
Expand Down Expand Up @@ -342,8 +342,8 @@ TEST_F(TsiSocketTest, HandshakeWithUnusedData) {
client_to_server_.add(makeFakeTsiFrame(ClientToServerData));

EXPECT_CALL(*server_.raw_socket_, doRead(_));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doRead(_));
expectIoResult({Network::PostIoAction::KeepOpen, 17UL, false},
server_.tsi_socket_->doRead(server_.read_buffer_));
Expand Down Expand Up @@ -378,8 +378,8 @@ TEST_F(TsiSocketTest, HandshakeWithUnusedDataAndEndOfStream) {
buffer.move(client_to_server_);
return result;
}));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
expectIoResult({Network::PostIoAction::KeepOpen, 17UL, true},
server_.tsi_socket_->doRead(server_.read_buffer_));
EXPECT_EQ(makeFakeTsiFrame("SERVER_FINISHED"), server_to_client_.toString());
Expand Down Expand Up @@ -841,7 +841,6 @@ TEST_F(TsiSocketTest, DoWriteOutstandingHandshakeData) {
EXPECT_EQ(0L, client_.read_buffer_.length());

EXPECT_CALL(*server_.raw_socket_, doRead(_));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));

// Write the first part of handshake data (14 bytes).
EXPECT_CALL(*server_.raw_socket_, doWrite(_, false))
Expand All @@ -850,6 +849,7 @@ TEST_F(TsiSocketTest, DoWriteOutstandingHandshakeData) {
server_to_client_.move(buffer, 14);
return result;
}));
EXPECT_CALL(server_.callbacks_, raiseEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(*server_.raw_socket_, doRead(_));
expectIoResult({Network::PostIoAction::KeepOpen, 0UL, false},
server_.tsi_socket_->doRead(server_.read_buffer_));
Expand Down
9 changes: 5 additions & 4 deletions test/integration/base_integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -389,9 +389,9 @@ void BaseIntegrationTest::createApiTestServer(const ApiFilesystemConfig& api_fil
port_names, validator_config, allow_lds_rejection);
}

void BaseIntegrationTest::sendRawHttpAndWaitForResponse(int port, const char* raw_http,
std::string* response,
bool disconnect_after_headers_complete) {
void BaseIntegrationTest::sendRawHttpAndWaitForResponse(
int port, const char* raw_http, std::string* response, bool disconnect_after_headers_complete,
Network::TransportSocketPtr transport_socket) {
auto connection = createConnectionDriver(
port, raw_http,
[response, disconnect_after_headers_complete](Network::ClientConnection& client,
Expand All @@ -400,7 +400,8 @@ void BaseIntegrationTest::sendRawHttpAndWaitForResponse(int port, const char* ra
if (disconnect_after_headers_complete && response->find("\r\n\r\n") != std::string::npos) {
client.close(Network::ConnectionCloseType::NoFlush);
}
});
},
std::move(transport_socket));

connection->run();
}
Expand Down
8 changes: 5 additions & 3 deletions test/integration/base_integration_test.h
Original file line number Diff line number Diff line change
Expand Up @@ -273,7 +273,8 @@ class BaseIntegrationTest : protected Logger::Loggable<Logger::Id::testing> {
* @param if the connection should be terminated once '\r\n\r\n' has been read.
**/
void sendRawHttpAndWaitForResponse(int port, const char* raw_http, std::string* response,
bool disconnect_after_headers_complete = false);
bool disconnect_after_headers_complete = false,
Network::TransportSocketPtr transport_socket = nullptr);

/**
* Helper to create ConnectionDriver.
Expand All @@ -284,10 +285,11 @@ class BaseIntegrationTest : protected Logger::Loggable<Logger::Id::testing> {
**/
std::unique_ptr<RawConnectionDriver> createConnectionDriver(
uint32_t port, const std::string& initial_data,
std::function<void(Network::ClientConnection&, const Buffer::Instance&)>&& data_callback) {
std::function<void(Network::ClientConnection&, const Buffer::Instance&)>&& data_callback,
Network::TransportSocketPtr transport_socket = nullptr) {
Buffer::OwnedImpl buffer(initial_data);
return std::make_unique<RawConnectionDriver>(port, buffer, data_callback, version_,
*dispatcher_);
*dispatcher_, std::move(transport_socket));
}

/**
Expand Down