diff --git a/source/extensions/transport_sockets/http_11_proxy/BUILD b/source/extensions/transport_sockets/http_11_proxy/BUILD index 4e74a546c0f95..66b49fc2baa1b 100644 --- a/source/extensions/transport_sockets/http_11_proxy/BUILD +++ b/source/extensions/transport_sockets/http_11_proxy/BUILD @@ -33,6 +33,8 @@ envoy_cc_library( "//source/common/buffer:buffer_lib", "//source/common/common:scalar_to_byte_vector_lib", "//source/common/common:utility_lib", + "//source/common/http/http1:balsa_parser_lib", + "//source/common/http/http1:legacy_parser_lib", "//source/common/network:address_lib", "//source/extensions/transport_sockets/common:passthrough_lib", ], diff --git a/source/extensions/transport_sockets/http_11_proxy/connect.cc b/source/extensions/transport_sockets/http_11_proxy/connect.cc index ecadff478e830..889e8f0c51582 100644 --- a/source/extensions/transport_sockets/http_11_proxy/connect.cc +++ b/source/extensions/transport_sockets/http_11_proxy/connect.cc @@ -8,12 +8,25 @@ #include "source/common/common/scalar_to_byte_vector.h" #include "source/common/common/utility.h" #include "source/common/network/address_impl.h" +#include "source/common/runtime/runtime_features.h" namespace Envoy { namespace Extensions { namespace TransportSockets { namespace Http11Connect { +bool UpstreamHttp11ConnectSocket::isValidConnectResponse(Buffer::Instance& buffer) { + SelfContainedParser parser; + while (parser.parser().getStatus() == Http::Http1::ParserStatus::Ok && + !parser.headersComplete() && buffer.length() != 0) { + auto slice = buffer.frontSlice(); + int parsed = parser.parser().execute(static_cast(slice.mem_), slice.len_); + buffer.drain(parsed); + } + return parser.parser().getStatus() != Http::Http1::ParserStatus::Error && + parser.headersComplete() && parser.parser().statusCode() == 200; +} + UpstreamHttp11ConnectSocket::UpstreamHttp11ConnectSocket( Network::TransportSocketPtr&& transport_socket, Network::TransportSocketOptionsConstSharedPtr options) @@ -45,8 +58,8 @@ Network::IoResult UpstreamHttp11ConnectSocket::doWrite(Buffer::Instance& buffer, Network::IoResult UpstreamHttp11ConnectSocket::doRead(Buffer::Instance& buffer) { if (need_to_strip_connect_response_) { - // Limit the CONNECT response headers to an arbitrary 200 bytes. - constexpr uint32_t MAX_RESPONSE_HEADER_SIZE = 200; + // Limit the CONNECT response headers to an arbitrary 2000 bytes. + constexpr uint32_t MAX_RESPONSE_HEADER_SIZE = 2000; char peek_buf[MAX_RESPONSE_HEADER_SIZE]; Api::IoCallUint64Result result = callbacks_->ioHandle().recv(peek_buf, MAX_RESPONSE_HEADER_SIZE, MSG_PEEK); @@ -68,16 +81,13 @@ Network::IoResult UpstreamHttp11ConnectSocket::doRead(Buffer::Instance& buffer) ENVOY_CONN_LOG(trace, "failed to drain CONNECT header", callbacks_->connection()); return {Network::PostIoAction::Close, 0, false}; } - // Note this is not in any way proper HTTP/1.1 parsing. - // Before this is used with any untrusted upstream, proper checks should - // be done rather than this. - if (!absl::StartsWith(peek_data, "HTTP/1.1 200")) { - ENVOY_CONN_LOG(trace, "Response does not match strict connect checks", + // Make sure the response is a valid connect response and all the data is consumed. + if (!isValidConnectResponse(buffer) || buffer.length() != 0) { + ENVOY_CONN_LOG(trace, "Response does not appear to be a successful CONNECT upgrade", callbacks_->connection()); return {Network::PostIoAction::Close, 0, false}; } ENVOY_CONN_LOG(trace, "Successfully stripped CONNECT header", callbacks_->connection()); - buffer.drain(buffer.length()); need_to_strip_connect_response_ = false; } return transport_socket_->doRead(buffer); diff --git a/source/extensions/transport_sockets/http_11_proxy/connect.h b/source/extensions/transport_sockets/http_11_proxy/connect.h index fb51f70f96f8c..9279d864fb37c 100644 --- a/source/extensions/transport_sockets/http_11_proxy/connect.h +++ b/source/extensions/transport_sockets/http_11_proxy/connect.h @@ -5,6 +5,7 @@ #include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" +#include "source/common/http/http1/balsa_parser.h" #include "source/extensions/transport_sockets/common/passthrough.h" namespace Envoy { @@ -18,6 +19,8 @@ namespace Http11Connect { class UpstreamHttp11ConnectSocket : public TransportSockets::PassthroughSocket, public Logger::Loggable { public: + static bool isValidConnectResponse(Buffer::Instance& buffer); + UpstreamHttp11ConnectSocket(Network::TransportSocketPtr&& transport_socket, Network::TransportSocketOptionsConstSharedPtr options); @@ -48,6 +51,45 @@ class UpstreamHttp11ConnectSocketFactory : public PassthroughFactory { Network::TransportSocketOptionsConstSharedPtr options) const override; }; +// This is a utility class for isValidConnectResponse. It is only exposed for +// coverage testing purposes. See isValidConnectResponse for intended use. +class SelfContainedParser : public Http::Http1::ParserCallbacks { +public: + SelfContainedParser() : parser_(Http::Http1::MessageType::Response, this, 2000) {} + Http::Http1::CallbackResult onMessageBegin() override { + return Http::Http1::CallbackResult::Success; + } + Http::Http1::CallbackResult onUrl(const char*, size_t) override { + return Http::Http1::CallbackResult::Success; + } + Http::Http1::CallbackResult onStatus(const char*, size_t) override { + return Http::Http1::CallbackResult::Success; + } + Http::Http1::CallbackResult onHeaderField(const char*, size_t) override { + return Http::Http1::CallbackResult::Success; + } + Http::Http1::CallbackResult onHeaderValue(const char*, size_t) override { + return Http::Http1::CallbackResult::Success; + } + Http::Http1::CallbackResult onHeadersComplete() override { + headers_complete_ = true; + parser_.pause(); + return Http::Http1::CallbackResult::Success; + } + void bufferBody(const char*, size_t) override {} + Http::Http1::CallbackResult onMessageComplete() override { + return Http::Http1::CallbackResult::Success; + } + void onChunkHeader(bool) override {} + + bool headersComplete() const { return headers_complete_; } + Http::Http1::BalsaParser& parser() { return parser_; } + +private: + bool headers_complete_ = false; + Http::Http1::BalsaParser parser_; +}; + } // namespace Http11Connect } // namespace TransportSockets } // namespace Extensions diff --git a/test/extensions/transport_sockets/http_11_proxy/BUILD b/test/extensions/transport_sockets/http_11_proxy/BUILD index ed9fa885ae5fc..ee2b76e71d016 100644 --- a/test/extensions/transport_sockets/http_11_proxy/BUILD +++ b/test/extensions/transport_sockets/http_11_proxy/BUILD @@ -21,6 +21,7 @@ envoy_extension_cc_test( "//test/mocks/network:io_handle_mocks", "//test/mocks/network:network_mocks", "//test/mocks/network:transport_socket_mocks", + "//test/test_common:test_runtime_lib", ], ) diff --git a/test/extensions/transport_sockets/http_11_proxy/connect_test.cc b/test/extensions/transport_sockets/http_11_proxy/connect_test.cc index c439813e04906..745ceb5984b41 100644 --- a/test/extensions/transport_sockets/http_11_proxy/connect_test.cc +++ b/test/extensions/transport_sockets/http_11_proxy/connect_test.cc @@ -83,8 +83,8 @@ class Http11ConnectTest : public testing::TestWithParam>()}; }; -// Test injects CONNECT only once -TEST_P(Http11ConnectTest, InjectesHeaderOnlyOnce) { +// Test injects CONNECT only once +TEST_P(Http11ConnectTest, InjectsHeaderOnlyOnce) { initialize(); EXPECT_CALL(io_handle_, write(BufferStringEqual(connect_data_.toString()))) @@ -219,7 +219,7 @@ TEST_P(Http11ConnectTest, StipsHeaderOnce) { std::string connect("HTTP/1.1 200 OK\r\n\r\n"); std::string initial_data(connect + "follow up data"); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)) + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)) .WillOnce(Invoke([&initial_data](void* buffer, size_t, int) { memcpy(buffer, initial_data.data(), initial_data.length()); return Api::IoCallUint64Result(initial_data.length(), @@ -227,8 +227,11 @@ TEST_P(Http11ConnectTest, StipsHeaderOnce) { })); absl::optional expected_bytes(connect.length()); EXPECT_CALL(io_handle_, read(_, expected_bytes)) - .WillOnce(Return(ByMove(Api::IoCallUint64Result( - connect.length(), Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + .WillOnce(Invoke([&](Buffer::Instance& buffer, absl::optional) { + buffer.add(connect); + return Api::IoCallUint64Result(connect.length(), + Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + })); EXPECT_CALL(*inner_socket_, doRead(_)) .WillOnce(Return(Network::IoResult{Network::PostIoAction::KeepOpen, 1, false})); Buffer::OwnedImpl buffer(""); @@ -241,7 +244,7 @@ TEST_P(Http11ConnectTest, InsufficientData) { std::string connect("HTTP/1.1 200 OK\r\n\r"); std::string initial_data(connect + "follow up data"); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)) + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)) .WillOnce(Invoke([&initial_data](void* buffer, size_t, int) { memcpy(buffer, initial_data.data(), initial_data.length()); return Api::IoCallUint64Result(initial_data.length(), @@ -258,7 +261,7 @@ TEST_P(Http11ConnectTest, PeekFail) { std::string connect("HTTP/1.1 200 OK\r\n\r\n"); std::string initial_data(connect + "follow up data"); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)) + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)) .WillOnce(Return(ByMove( Api::IoCallUint64Result({}, Api::IoErrorPtr(new Network::IoSocketError(EADDRNOTAVAIL), Network::IoSocketError::deleteIoError))))); @@ -276,7 +279,7 @@ TEST_P(Http11ConnectTest, ReadFail) { std::string connect("HTTP/1.1 200 OK\r\n\r\n"); std::string initial_data(connect + "follow up data"); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)) + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)) .WillOnce(Invoke([&initial_data](void* buffer, size_t, int) { memcpy(buffer, initial_data.data(), initial_data.length()); return Api::IoCallUint64Result(initial_data.length(), @@ -300,7 +303,7 @@ TEST_P(Http11ConnectTest, ShortRead) { std::string connect("HTTP/1.1 200 OK\r\n\r\n"); std::string initial_data(connect + "follow up data"); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)) + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)) .WillOnce(Invoke([&initial_data](void* buffer, size_t, int) { memcpy(buffer, initial_data.data(), initial_data.length()); return Api::IoCallUint64Result(initial_data.length(), @@ -317,13 +320,13 @@ TEST_P(Http11ConnectTest, ShortRead) { EXPECT_EQ(Network::PostIoAction::Close, result.action_); } -// If headers exceed 200 bytes, read fails. +// If headers exceed 2000 bytes, read fails. TEST_P(Http11ConnectTest, LongHeaders) { initialize(); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)).WillOnce(Invoke([](void* buffer, size_t, int) { - memset(buffer, 0, 200); - return Api::IoCallUint64Result(200, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)).WillOnce(Invoke([](void* buffer, size_t, int) { + memset(buffer, 0, 2000); + return Api::IoCallUint64Result(2000, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); })); EXPECT_CALL(io_handle_, read(_, _)).Times(0); EXPECT_CALL(*inner_socket_, doRead(_)).Times(0); @@ -339,7 +342,7 @@ TEST_P(Http11ConnectTest, InvalidResponse) { std::string connect("HTTP/1.1 404 Not Found\r\n\r\n"); std::string initial_data(connect + "follow up data"); - EXPECT_CALL(io_handle_, recv(_, 200, MSG_PEEK)) + EXPECT_CALL(io_handle_, recv(_, 2000, MSG_PEEK)) .WillOnce(Invoke([&initial_data](void* buffer, size_t, int) { memcpy(buffer, initial_data.data(), initial_data.length()); return Api::IoCallUint64Result(initial_data.length(), @@ -353,7 +356,7 @@ TEST_P(Http11ConnectTest, InvalidResponse) { EXPECT_CALL(*inner_socket_, doRead(_)).Times(0); Buffer::OwnedImpl buffer(""); - EXPECT_LOG_CONTAINS("trace", "Response does not match strict connect checks", { + EXPECT_LOG_CONTAINS("trace", "Response does not appear to be a successful CONNECT upgrade", { auto result = connect_socket_->doRead(buffer); EXPECT_EQ(Network::PostIoAction::Close, result.action_); }); @@ -382,6 +385,40 @@ TEST_F(SocketFactoryTest, CreateSocketReturnsNullWhenInnerFactoryReturnsNull) { ASSERT_EQ(nullptr, factory_->createTransportSocket(nullptr, nullptr)); } +TEST(ParseTest, TestValidResponse) { + { + Buffer::OwnedImpl buffer("HTTP/1.0 200 OK\r\n\r\n"); + ASSERT_TRUE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer)); + EXPECT_EQ(buffer.length(), 0); + } + { + Buffer::OwnedImpl buffer("HTTP/1.0 200 OK\r\nFoo: Bar\r\n\r\n"); + ASSERT_TRUE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer)); + EXPECT_EQ(buffer.length(), 0); + } + { + Buffer::OwnedImpl buffer("HTTP/1.1 200 OK \r\n\r\nasdf"); + ASSERT_TRUE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer)); + EXPECT_EQ(buffer.length(), 4); + EXPECT_EQ(buffer.toString(), "asdf"); + } + { + Buffer::OwnedImpl buffer("HTTP/1.0 300 OK\r\n\r\n"); + ASSERT_FALSE(UpstreamHttp11ConnectSocket::isValidConnectResponse(buffer)); + } +} + +// The SelfContainedParser is only intended for header parsing but for coverage, +// test a request with a body. +TEST(ParseTest, CoverResponseBody) { + std::string headers = "HTTP/1.0 200 OK\r\ncontent-length: 2\r\n\r\n"; + std::string body = "ab"; + + SelfContainedParser parser; + parser.parser().execute(headers.c_str(), headers.length()); + parser.parser().execute(body.c_str(), body.length()); +} + } // namespace } // namespace Http11Connect } // namespace TransportSockets