Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions source/extensions/transport_sockets/http_11_proxy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ envoy_cc_library(
"//source/common/buffer:buffer_lib",
"//source/common/common:scalar_to_byte_vector_lib",
"//source/common/common:utility_lib",
"//source/common/http/http1:balsa_parser_lib",
"//source/common/http/http1:legacy_parser_lib",
"//source/common/network:address_lib",
"//source/extensions/transport_sockets/common:passthrough_lib",
],
Expand Down
26 changes: 18 additions & 8 deletions source/extensions/transport_sockets/http_11_proxy/connect.cc
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,25 @@
#include "source/common/common/scalar_to_byte_vector.h"
#include "source/common/common/utility.h"
#include "source/common/network/address_impl.h"
#include "source/common/runtime/runtime_features.h"

namespace Envoy {
namespace Extensions {
namespace TransportSockets {
namespace Http11Connect {

bool UpstreamHttp11ConnectSocket::isValidConnectResponse(Buffer::Instance& buffer) {
SelfContainedParser parser;
while (parser.parser().getStatus() == Http::Http1::ParserStatus::Ok &&
!parser.headersComplete() && buffer.length() != 0) {
auto slice = buffer.frontSlice();
int parsed = parser.parser().execute(static_cast<const char*>(slice.mem_), slice.len_);
buffer.drain(parsed);
}
return parser.parser().getStatus() != Http::Http1::ParserStatus::Error &&
parser.headersComplete() && parser.parser().statusCode() == 200;
}

UpstreamHttp11ConnectSocket::UpstreamHttp11ConnectSocket(
Network::TransportSocketPtr&& transport_socket,
Network::TransportSocketOptionsConstSharedPtr options)
Expand Down Expand Up @@ -45,8 +58,8 @@ Network::IoResult UpstreamHttp11ConnectSocket::doWrite(Buffer::Instance& buffer,

Network::IoResult UpstreamHttp11ConnectSocket::doRead(Buffer::Instance& buffer) {
if (need_to_strip_connect_response_) {
// Limit the CONNECT response headers to an arbitrary 200 bytes.
constexpr uint32_t MAX_RESPONSE_HEADER_SIZE = 200;
// Limit the CONNECT response headers to an arbitrary 2000 bytes.
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I know this is arbitrary but why the increase?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I decided the prior arbitrary limit was too small :-)
Kuat has been talking about adding extra headers to the CONNECT headers, and Rafal was adding test headers in his integration test, so I want to handle more than just a raw response.

constexpr uint32_t MAX_RESPONSE_HEADER_SIZE = 2000;
char peek_buf[MAX_RESPONSE_HEADER_SIZE];
Api::IoCallUint64Result result =
callbacks_->ioHandle().recv(peek_buf, MAX_RESPONSE_HEADER_SIZE, MSG_PEEK);
Expand All @@ -68,16 +81,13 @@ Network::IoResult UpstreamHttp11ConnectSocket::doRead(Buffer::Instance& buffer)
ENVOY_CONN_LOG(trace, "failed to drain CONNECT header", callbacks_->connection());
return {Network::PostIoAction::Close, 0, false};
}
// Note this is not in any way proper HTTP/1.1 parsing.
// Before this is used with any untrusted upstream, proper checks should
// be done rather than this.
if (!absl::StartsWith(peek_data, "HTTP/1.1 200")) {
ENVOY_CONN_LOG(trace, "Response does not match strict connect checks",
// Make sure the response is a valid connect response and all the data is consumed.
if (!isValidConnectResponse(buffer) || buffer.length() != 0) {
ENVOY_CONN_LOG(trace, "Response does not appear to be a successful CONNECT upgrade",
callbacks_->connection());
return {Network::PostIoAction::Close, 0, false};
}
ENVOY_CONN_LOG(trace, "Successfully stripped CONNECT header", callbacks_->connection());
buffer.drain(buffer.length());
need_to_strip_connect_response_ = false;
}
return transport_socket_->doRead(buffer);
Expand Down
42 changes: 42 additions & 0 deletions source/extensions/transport_sockets/http_11_proxy/connect.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "source/common/buffer/buffer_impl.h"
#include "source/common/common/logger.h"
#include "source/common/http/http1/balsa_parser.h"
#include "source/extensions/transport_sockets/common/passthrough.h"

namespace Envoy {
Expand All @@ -18,6 +19,8 @@ namespace Http11Connect {
class UpstreamHttp11ConnectSocket : public TransportSockets::PassthroughSocket,
public Logger::Loggable<Logger::Id::connection> {
public:
static bool isValidConnectResponse(Buffer::Instance& buffer);

UpstreamHttp11ConnectSocket(Network::TransportSocketPtr&& transport_socket,
Network::TransportSocketOptionsConstSharedPtr options);

Expand Down Expand Up @@ -48,6 +51,45 @@ class UpstreamHttp11ConnectSocketFactory : public PassthroughFactory {
Network::TransportSocketOptionsConstSharedPtr options) const override;
};

// This is a utility class for isValidConnectResponse. It is only exposed for
// coverage testing purposes. See isValidConnectResponse for intended use.
class SelfContainedParser : public Http::Http1::ParserCallbacks {
public:
SelfContainedParser() : parser_(Http::Http1::MessageType::Response, this, 2000) {}
Http::Http1::CallbackResult onMessageBegin() override {
return Http::Http1::CallbackResult::Success;
}
Http::Http1::CallbackResult onUrl(const char*, size_t) override {
return Http::Http1::CallbackResult::Success;
}
Http::Http1::CallbackResult onStatus(const char*, size_t) override {
return Http::Http1::CallbackResult::Success;
}
Http::Http1::CallbackResult onHeaderField(const char*, size_t) override {
return Http::Http1::CallbackResult::Success;
}
Http::Http1::CallbackResult onHeaderValue(const char*, size_t) override {
return Http::Http1::CallbackResult::Success;
}
Http::Http1::CallbackResult onHeadersComplete() override {
headers_complete_ = true;
parser_.pause();
return Http::Http1::CallbackResult::Success;
}
void bufferBody(const char*, size_t) override {}
Http::Http1::CallbackResult onMessageComplete() override {
return Http::Http1::CallbackResult::Success;
}
void onChunkHeader(bool) override {}

bool headersComplete() const { return headers_complete_; }
Http::Http1::BalsaParser& parser() { return parser_; }

private:
bool headers_complete_ = false;
Http::Http1::BalsaParser parser_;
};

} // namespace Http11Connect
} // namespace TransportSockets
} // namespace Extensions
Expand Down
1 change: 1 addition & 0 deletions test/extensions/transport_sockets/http_11_proxy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ envoy_extension_cc_test(
"//test/mocks/network:io_handle_mocks",
"//test/mocks/network:network_mocks",
"//test/mocks/network:transport_socket_mocks",
"//test/test_common:test_runtime_lib",
],
)

Expand Down
67 changes: 52 additions & 15 deletions test/extensions/transport_sockets/http_11_proxy/connect_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -83,8 +83,8 @@ class Http11ConnectTest : public testing::TestWithParam<Network::Address::IpVers
std::make_shared<NiceMock<Envoy::Ssl::MockConnectionInfo>>()};
};

// Test injects CONNECT only once
TEST_P(Http11ConnectTest, InjectesHeaderOnlyOnce) {
// Test injects CONNECT only once
TEST_P(Http11ConnectTest, InjectsHeaderOnlyOnce) {
initialize();

EXPECT_CALL(io_handle_, write(BufferStringEqual(connect_data_.toString())))
Expand Down Expand Up @@ -219,16 +219,19 @@ TEST_P(Http11ConnectTest, StipsHeaderOnce) {

std::string connect("HTTP/1.1 200 OK\r\n\r\n");
std::string initial_data(connect + "follow up data");
EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK))
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK))
.WillOnce(Invoke([&initial_data](void* buffer, size_t, int) {
memcpy(buffer, initial_data.data(), initial_data.length());
return Api::IoCallUint64Result(initial_data.length(),
Api::IoErrorPtr(nullptr, [](Api::IoError*) {}));
}));
absl::optional<uint64_t> expected_bytes(connect.length());
EXPECT_CALL(io_handle_, read(_, expected_bytes))
.WillOnce(Return(ByMove(Api::IoCallUint64Result(
connect.length(), Api::IoErrorPtr(nullptr, [](Api::IoError*) {})))));
.WillOnce(Invoke([&](Buffer::Instance& buffer, absl::optional<uint64_t>) {
buffer.add(connect);
return Api::IoCallUint64Result(connect.length(),
Api::IoErrorPtr(nullptr, [](Api::IoError*) {}));
}));
EXPECT_CALL(*inner_socket_, doRead(_))
.WillOnce(Return(Network::IoResult{Network::PostIoAction::KeepOpen, 1, false}));
Buffer::OwnedImpl buffer("");
Expand All @@ -241,7 +244,7 @@ TEST_P(Http11ConnectTest, InsufficientData) {

std::string connect("HTTP/1.1 200 OK\r\n\r");
std::string initial_data(connect + "follow up data");
EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK))
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK))
.WillOnce(Invoke([&initial_data](void* buffer, size_t, int) {
memcpy(buffer, initial_data.data(), initial_data.length());
return Api::IoCallUint64Result(initial_data.length(),
Expand All @@ -258,7 +261,7 @@ TEST_P(Http11ConnectTest, PeekFail) {

std::string connect("HTTP/1.1 200 OK\r\n\r\n");
std::string initial_data(connect + "follow up data");
EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK))
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK))
.WillOnce(Return(ByMove(
Api::IoCallUint64Result({}, Api::IoErrorPtr(new Network::IoSocketError(EADDRNOTAVAIL),
Network::IoSocketError::deleteIoError)))));
Expand All @@ -276,7 +279,7 @@ TEST_P(Http11ConnectTest, ReadFail) {

std::string connect("HTTP/1.1 200 OK\r\n\r\n");
std::string initial_data(connect + "follow up data");
EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK))
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK))
.WillOnce(Invoke([&initial_data](void* buffer, size_t, int) {
memcpy(buffer, initial_data.data(), initial_data.length());
return Api::IoCallUint64Result(initial_data.length(),
Expand All @@ -300,7 +303,7 @@ TEST_P(Http11ConnectTest, ShortRead) {

std::string connect("HTTP/1.1 200 OK\r\n\r\n");
std::string initial_data(connect + "follow up data");
EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK))
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK))
.WillOnce(Invoke([&initial_data](void* buffer, size_t, int) {
memcpy(buffer, initial_data.data(), initial_data.length());
return Api::IoCallUint64Result(initial_data.length(),
Expand All @@ -317,13 +320,13 @@ TEST_P(Http11ConnectTest, ShortRead) {
EXPECT_EQ(Network::PostIoAction::Close, result.action_);
}

// If headers exceed 200 bytes, read fails.
// If headers exceed 2000 bytes, read fails.
TEST_P(Http11ConnectTest, LongHeaders) {
initialize();

EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)).WillOnce(Invoke([](void* buffer, size_t, int) {
memset(buffer, 0, 200);
return Api::IoCallUint64Result(200, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}));
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)).WillOnce(Invoke([](void* buffer, size_t, int) {
memset(buffer, 0, 2000);
return Api::IoCallUint64Result(2000, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}));
}));
EXPECT_CALL(io_handle_, read(_, _)).Times(0);
EXPECT_CALL(*inner_socket_, doRead(_)).Times(0);
Expand All @@ -339,7 +342,7 @@ TEST_P(Http11ConnectTest, InvalidResponse) {

std::string connect("HTTP/1.1 404 Not Found\r\n\r\n");
std::string initial_data(connect + "follow up data");
EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK))
EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK))
.WillOnce(Invoke([&initial_data](void* buffer, size_t, int) {
memcpy(buffer, initial_data.data(), initial_data.length());
return Api::IoCallUint64Result(initial_data.length(),
Expand All @@ -353,7 +356,7 @@ TEST_P(Http11ConnectTest, InvalidResponse) {
EXPECT_CALL(*inner_socket_, doRead(_)).Times(0);

Buffer::OwnedImpl buffer("");
EXPECT_LOG_CONTAINS("trace", "Response does not match strict connect checks", {
EXPECT_LOG_CONTAINS("trace", "Response does not appear to be a successful CONNECT upgrade", {
auto result = connect_socket_->doRead(buffer);
EXPECT_EQ(Network::PostIoAction::Close, result.action_);
});
Expand Down Expand Up @@ -382,6 +385,40 @@ TEST_F(SocketFactoryTest, CreateSocketReturnsNullWhenInnerFactoryReturnsNull) {
ASSERT_EQ(nullptr, factory_->createTransportSocket(nullptr, nullptr));
}

TEST(ParseTest, TestValidResponse) {
{
Buffer::OwnedImpl buffer("HTTP/1.0 200 OK\r\n\r\n");
ASSERT_TRUE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer));
EXPECT_EQ(buffer.length(), 0);
}
{
Buffer::OwnedImpl buffer("HTTP/1.0 200 OK\r\nFoo: Bar\r\n\r\n");
ASSERT_TRUE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer));
EXPECT_EQ(buffer.length(), 0);
}
{
Buffer::OwnedImpl buffer("HTTP/1.1 200 OK \r\n\r\nasdf");
ASSERT_TRUE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer));
EXPECT_EQ(buffer.length(), 4);
EXPECT_EQ(buffer.toString(), "asdf");
}
{
Buffer::OwnedImpl buffer("HTTP/1.0 300 OK\r\n\r\n");
ASSERT_FALSE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer));
}
}

// The SelfContainedParser is only intended for header parsing but for coverage,
// test a request with a body.
TEST(ParseTest, CoverResponseBody) {
std::string headers = "HTTP/1.0 200 OK\r\ncontent-length: 2\r\n\r\n";
std::string body = "ab";

SelfContainedParser parser;
parser.parser().execute(headers.c_str(), headers.length());
parser.parser().execute(body.c_str(), body.length());
}

} // namespace
} // namespace Http11Connect
} // namespace TransportSockets
Expand Down