diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index 011949996723b..4e9103ab9acd7 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -91,7 +91,6 @@ envoy_cc_library( envoy_cc_library( name = "metadata_lib", - srcs = ["metadata.cc"], hdrs = ["metadata.h"], external_deps = ["abseil_optional"], deps = [ diff --git a/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc index 0dee09aed8c0b..1b7b36cc5a7c4 100644 --- a/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc @@ -144,9 +144,10 @@ bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMeta } while (num_headers-- > 0) { - std::string key = drainVarString(buffer, header_size, "header key"); - std::string value = drainVarString(buffer, header_size, "header value"); - metadata.addHeader(Header(key, value)); + const Http::LowerCaseString key = + Http::LowerCaseString(drainVarString(buffer, header_size, "header key")); + const std::string value = drainVarString(buffer, header_size, "header value"); + metadata.headers().addCopy(key, value); } } @@ -172,7 +173,7 @@ void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMet throw EnvoyException(fmt::format("invalid thrift header transport message size {}", msg_size)); } - const HeaderMap& headers = metadata.headers(); + const Http::HeaderMap& headers = metadata.headers(); if (headers.size() > MaxHeadersSize / 2) { // Each header takes a minimum of 2 bytes, yielding this limit. throw EnvoyException( @@ -205,10 +206,14 @@ void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMet // Num headers BufferHelper::writeVarIntI32(header_buffer, static_cast(headers.size())); - for (const Header& header : headers) { - writeVarString(header_buffer, header.key()); - writeVarString(header_buffer, header.value()); - } + headers.iterate( + [](const Http::HeaderEntry& header, void* context) -> Http::HeaderMap::Iterate { + Buffer::Instance* hb = static_cast(context); + writeVarString(*hb, header.key().getStringView()); + writeVarString(*hb, header.value().getStringView()); + return Http::HeaderMap::Iterate::Continue; + }, + &header_buffer); } uint64_t header_size = header_buffer.length(); @@ -286,7 +291,7 @@ std::string HeaderTransportImpl::drainVarString(Buffer::Instance& buffer, int32_ return value; } -void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const std::string& str) { +void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const absl::string_view str) { std::string::size_type len = str.length(); if (len > static_cast(std::numeric_limits::max())) { throw EnvoyException(fmt::format("header string too long: {}", len)); @@ -296,7 +301,7 @@ void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const std::st if (len == 0) { return; } - buffer.add(str); + buffer.add(str.data(), len); } class HeaderTransportConfigFactory : public TransportFactoryBase { diff --git a/source/extensions/filters/network/thrift_proxy/header_transport_impl.h b/source/extensions/filters/network/thrift_proxy/header_transport_impl.h index 88807904018f9..02c7d051cb503 100644 --- a/source/extensions/filters/network/thrift_proxy/header_transport_impl.h +++ b/source/extensions/filters/network/thrift_proxy/header_transport_impl.h @@ -43,7 +43,7 @@ class HeaderTransportImpl : public Transport { static int32_t drainVarIntI32(Buffer::Instance& buffer, int32_t& header_size, const char* desc); static std::string drainVarString(Buffer::Instance& buffer, int32_t& header_size, const char* desc); - static void writeVarString(Buffer::Instance& buffer, const std::string& str); + static void writeVarString(Buffer::Instance& buffer, const absl::string_view str); void setException(AppExceptionType type, std::string reason) { if (exception_.has_value()) { diff --git a/source/extensions/filters/network/thrift_proxy/metadata.cc b/source/extensions/filters/network/thrift_proxy/metadata.cc deleted file mode 100644 index 1176773478840..0000000000000 --- a/source/extensions/filters/network/thrift_proxy/metadata.cc +++ /dev/null @@ -1,47 +0,0 @@ -#include "extensions/filters/network/thrift_proxy/metadata.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace ThriftProxy { - -HeaderMap::HeaderMap(const std::initializer_list>& values) { - for (auto& value : values) { - headers_.emplace_back(Header(value.first, value.second)); - } -} - -HeaderMap::HeaderMap(const HeaderMap& rhs) { - for (const auto& header : rhs.headers_) { - headers_.emplace_back(header); - } -} - -bool HeaderMap::operator==(const HeaderMap& rhs) const { - if (headers_.size() != rhs.headers_.size()) { - return false; - } - - for (auto i = headers_.begin(), j = rhs.headers_.begin(); i != headers_.end(); ++i, ++j) { - if (i->key() != j->key() || i->value() != j->value()) { - return false; - } - } - - return true; -} - -Header* HeaderMap::get(const std::string& key) { - for (Header& header : headers_) { - if (header.key() == key) { - return &header; - } - } - - return nullptr; -} - -} // namespace ThriftProxy -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h index a7025739df36d..e9498d88eef43 100644 --- a/source/extensions/filters/network/thrift_proxy/metadata.h +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -7,6 +7,7 @@ #include #include "common/common/macros.h" +#include "common/http/header_map_impl.h" #include "extensions/filters/network/thrift_proxy/thrift.h" @@ -17,82 +18,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -/** - * Header is a name-value pair in Thrift transport or protocol headers. - */ -class Header { -public: - Header(const std::string key, const std::string value) : key_(key), value_(value) {} - Header(const Header& rhs) : key_(rhs.key_), value_(rhs.value_) {} - - const std::string& key() const { return key_; } - const std::string& value() const { return value_; } - -private: - std::string key_; - std::string value_; -}; - -// TODO(zuercher): replace this with Http::HeaderMap[Impl] -/* - * HeaderMap contains Thrift transport and/or protocol-level headers. - */ -class HeaderMap { -public: - HeaderMap() {} - HeaderMap(const std::initializer_list>& values); - HeaderMap(const HeaderMap& rhs); - - /** - * @return true if the HeaderMap is empty - */ - bool empty() const { return headers_.empty(); } - - /** - * @return uint32_t the number of headers in the map - */ - uint32_t size() const { return headers_.size(); } - - /** - * @param header Header to move into the HeaderMap - */ - void add(Header&& header) { headers_.emplace_back(std::move(header)); } - - /** - * Clears all Headers from the HeaderMap. - */ - void clear() { headers_.clear(); } - - /** - * Retrieves a Header from the HeaderMap. - * @param key std::string containing the key to lookup - * @return Header* corresponding to key or nullptr if not found. - */ - Header* get(const std::string& key); - - /** - * Const iterators for the HeaderMap. - */ - std::list
::const_iterator begin() const noexcept { return headers_.begin(); } - std::list
::const_iterator end() const noexcept { return headers_.end(); } - std::list
::const_iterator cbegin() const noexcept { return headers_.cbegin(); } - std::list
::const_iterator cend() const noexcept { return headers_.cend(); } - - /** - * For testing. Equality is based on equality of the backing list. This is an exact match - * comparison (order matters). - */ - bool operator==(const HeaderMap& rhs) const; - - /** - * @return an empty HeaderMap - */ - static const HeaderMap& emptyHeaderMap() { CONSTRUCT_ON_FIRST_USE(HeaderMap, HeaderMap({})); } - -private: - std::list
headers_; -}; - /** * MessageMetadata encapsulates metadata about Thrift messages. The various fields are considered * optional since they may come from either the transport or protocol in some cases. Unless @@ -123,11 +48,11 @@ class MessageMetadata { MessageType messageType() const { return msg_type_.value(); } void setMessageType(MessageType msg_type) { msg_type_ = msg_type; } - void addHeader(Header&& header) { headers_.add(std::move(header)); } /** * @return HeaderMap of current headers (never throws) */ - const HeaderMap& headers() const { return headers_; } + const Http::HeaderMap& headers() const { return headers_; } + Http::HeaderMap& headers() { return headers_; } bool hasAppException() const { return app_ex_type_.has_value(); } void setAppException(AppExceptionType app_ex_type, const std::string& message) { @@ -143,7 +68,7 @@ class MessageMetadata { absl::optional method_name_{}; absl::optional seq_id_{}; absl::optional msg_type_{}; - HeaderMap headers_; + Http::HeaderMapImpl headers_; absl::optional app_ex_type_; absl::optional app_ex_msg_; }; diff --git a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc index dea7c694c5f1f..dad63019f04be 100644 --- a/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/binary_protocol_impl_test.cc @@ -36,7 +36,7 @@ class BinaryProtocolTest : public testing::Test { EXPECT_FALSE(metadata_.hasFrameSize()); EXPECT_FALSE(metadata_.hasProtocol()); EXPECT_FALSE(metadata_.hasAppException()); - EXPECT_TRUE(metadata_.headers().empty()); + EXPECT_EQ(metadata_.headers().size(), 0); } void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, 1); } diff --git a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc index 70e476a5ad9a9..4ae0eeb66d4f7 100644 --- a/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/compact_protocol_impl_test.cc @@ -39,7 +39,7 @@ class CompactProtocolTest : public testing::Test { EXPECT_FALSE(metadata_.hasFrameSize()); EXPECT_FALSE(metadata_.hasProtocol()); EXPECT_FALSE(metadata_.hasAppException()); - EXPECT_TRUE(metadata_.headers().empty()); + EXPECT_EQ(metadata_.headers().size(), 0); } void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, 1); } diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index 40a7f318cbb9e..b5a1966c720ea 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -1040,8 +1040,6 @@ TEST(DecoderTest, OnDataHandlesStopIterationAndResumes) { Buffer::OwnedImpl buffer; bool underflow = true; - HeaderMap headers{{"test", "header"}}; - EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { metadata.setFrameSize(100); diff --git a/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc index 1bcc81120ce58..dac67e4cca0b0 100644 --- a/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc @@ -31,10 +31,11 @@ class MockBuffer : public Envoy::MockBuffer { MOCK_CONST_METHOD0(length, uint64_t()); }; -MessageMetadata mkMessageMetadata(uint32_t num_headers) { - MessageMetadata metadata; +MessageMetadataSharedPtr mkMessageMetadata(uint32_t num_headers) { + MessageMetadataSharedPtr metadata = std::make_shared(); + while (num_headers-- > 0) { - metadata.addHeader(Header("x", "y")); + metadata->headers().addCopy(Http::LowerCaseString("x"), "y"); } return metadata; } @@ -439,7 +440,7 @@ TEST(HeaderTransportTest, InfoBlock) { HeaderTransportImpl transport; Buffer::OwnedImpl buffer; MessageMetadata metadata; - metadata.addHeader(Header("not", "empty")); + metadata.headers().addCopy(Http::LowerCaseString("not"), "empty"); addInt32(buffer, 200); addInt16(buffer, 0x0FFF); @@ -459,16 +460,17 @@ TEST(HeaderTransportTest, InfoBlock) { addInt8(buffer, 0); // empty value addInt8(buffer, 0); // padding - HeaderMap expected_headers{ - {"not", "empty"}, - {"key", "value"}, - {"key2", std::string(128, 'x')}, - {"", ""}, - }; + Http::HeaderMapImpl expected_headers; + expected_headers.addCopy(Http::LowerCaseString("not"), "empty"); + expected_headers.addCopy(Http::LowerCaseString("key"), "value"); + expected_headers.addCopy(Http::LowerCaseString("key2"), std::string(128, 'x')); + expected_headers.addCopy(Http::LowerCaseString(""), ""); EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); EXPECT_THAT(metadata, HasFrameSize(38U)); - EXPECT_EQ(expected_headers, metadata.headers()); + + Http::HeaderMapImpl& actual_headers = dynamic_cast(metadata.headers()); + EXPECT_EQ(expected_headers, actual_headers); EXPECT_EQ(buffer.length(), 0); } @@ -530,13 +532,13 @@ TEST(HeaderTransportImpl, TestEncodeFrame) { // Too many headers { Buffer::OwnedImpl buffer; - MessageMetadata metadata = mkMessageMetadata(32769); - metadata.setProtocol(ProtocolType::Binary); + MessageMetadataSharedPtr metadata = mkMessageMetadata(32769); + metadata->setProtocol(ProtocolType::Binary); Buffer::OwnedImpl msg; msg.add("fake message"); - EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, *metadata, msg), EnvoyException, "invalid thrift header transport too many headers 32769"); } @@ -545,7 +547,7 @@ TEST(HeaderTransportImpl, TestEncodeFrame) { Buffer::OwnedImpl buffer; MessageMetadata metadata; metadata.setProtocol(ProtocolType::Binary); - metadata.addHeader(Header("key", std::string(32768, 'x'))); + metadata.headers().addCopy(Http::LowerCaseString("key"), std::string(32768, 'x')); Buffer::OwnedImpl msg; msg.add("fake message"); @@ -559,10 +561,10 @@ TEST(HeaderTransportImpl, TestEncodeFrame) { Buffer::OwnedImpl buffer; MessageMetadata metadata; metadata.setProtocol(ProtocolType::Binary); - metadata.addHeader(Header("k1", std::string(16384, 'x'))); - metadata.addHeader(Header("k2", std::string(16384, 'x'))); - metadata.addHeader(Header("k3", std::string(16384, 'x'))); - metadata.addHeader(Header("k4", std::string(16384, 'x'))); + metadata.headers().addCopy(Http::LowerCaseString("k1"), std::string(16384, 'x')); + metadata.headers().addCopy(Http::LowerCaseString("k2"), std::string(16384, 'x')); + metadata.headers().addCopy(Http::LowerCaseString("k3"), std::string(16384, 'x')); + metadata.headers().addCopy(Http::LowerCaseString("k4"), std::string(16384, 'x')); Buffer::OwnedImpl msg; msg.add("fake message"); @@ -620,8 +622,8 @@ TEST(HeaderTransportImpl, TestEncodeFrame) { MessageMetadata metadata; metadata.setProtocol(ProtocolType::Compact); metadata.setSequenceId(10); - metadata.addHeader(Header("key", "value")); - metadata.addHeader(Header("", "")); + metadata.headers().addCopy(Http::LowerCaseString("key"), "value"); + metadata.headers().addCopy(Http::LowerCaseString(""), ""); Buffer::OwnedImpl msg; msg.add("fake message"); diff --git a/test/extensions/filters/network/thrift_proxy/metadata_test.cc b/test/extensions/filters/network/thrift_proxy/metadata_test.cc index d06fa46e6b81f..17bd4a0aaeb8c 100644 --- a/test/extensions/filters/network/thrift_proxy/metadata_test.cc +++ b/test/extensions/filters/network/thrift_proxy/metadata_test.cc @@ -10,112 +10,6 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(HeaderTest, HeaderKeyIsNotTransformed) { - Header hdr("KEY", "VALUE"); - EXPECT_EQ(hdr.key(), "KEY"); - EXPECT_EQ(hdr.value(), "VALUE"); -} - -TEST(HeaderTest, HeaderIsCopyable) { - Header hdr("KEY", "VALUE"); - Header hdrCopy(hdr); - EXPECT_EQ(hdrCopy.key(), "KEY"); - EXPECT_EQ(hdrCopy.value(), "VALUE"); -} - -TEST(HeaderMapTest, AddHeaders) { - HeaderMap headers; - headers.add(Header("k", "v")); - - Header* hdr = headers.get("k"); - EXPECT_NE(hdr, nullptr); - EXPECT_EQ(hdr->key(), "k"); - EXPECT_EQ(hdr->value(), "v"); -} - -TEST(HeaderMapTest, GetHeaders) { - HeaderMap headers({ - {"a", "b"}, - {"c", "d"}, - {"e", "f"}, - }); - - EXPECT_EQ(headers.get("a")->value(), "b"); - EXPECT_EQ(headers.get("c")->value(), "d"); - EXPECT_EQ(headers.get("e")->value(), "f"); -} - -TEST(HeaderMapTest, Clear) { - HeaderMap headers({ - {"a", "b"}, - {"c", "d"}, - {"e", "f"}, - }); - - headers.clear(); - EXPECT_EQ(headers.get("a"), nullptr); - EXPECT_EQ(headers.get("c"), nullptr); - EXPECT_EQ(headers.get("e"), nullptr); -} - -TEST(HeaderMapTest, Size) { - HeaderMap headers({ - {"a", "b"}, - {"c", "d"}, - {"e", "f"}, - }); - - EXPECT_EQ(3U, headers.size()); -} - -TEST(HeaderMapTest, Equality) { - HeaderMap headers1({{"FIRST", "1"}, {"Second", "2"}}); - HeaderMap headers2({{"FIRST", "1"}, {"Second", "2"}}); - HeaderMap headers3({{"FIRST", "1"}}); - HeaderMap headers4({{"FIRST", "_"}, {"Second", "2"}}); - HeaderMap headers5({{"First", "1"}, {"Second", "2"}}); - - EXPECT_EQ(headers1, headers2); - EXPECT_EQ(headers2, headers1); - - EXPECT_FALSE(headers1 == headers3); - EXPECT_FALSE(headers3 == headers1); - - EXPECT_FALSE(headers1 == headers4); - EXPECT_FALSE(headers4 == headers1); - - EXPECT_FALSE(headers1 == headers5); - EXPECT_FALSE(headers5 == headers1); -} - -TEST(HeaderMapTest, Iteration) { - HeaderMap headers({{"first", "1"}, {"second", "2"}}); - - int i = 0; - for (const Header& header : headers) { - switch (i) { - case 0: - EXPECT_EQ("first", header.key()); - EXPECT_EQ("1", header.value()); - break; - case 1: - EXPECT_EQ("second", header.key()); - EXPECT_EQ("2", header.value()); - break; - default: - ASSERT(false); - break; - } - i++; - } -} - -TEST(HeaderMapTest, CopyConstructor) { - HeaderMap headers({{"first", "1"}, {"second", "2"}, {"third", "3"}}); - HeaderMap copy(headers); - EXPECT_EQ(copy, headers); -} - TEST(MessageMetadataTest, Fields) { MessageMetadata metadata; @@ -161,10 +55,9 @@ TEST(MessageMetadataTest, Fields) { TEST(MessageMetadataTest, Headers) { MessageMetadata metadata; - EXPECT_TRUE(metadata.headers().empty()); - - metadata.addHeader(Header("k", "v")); - EXPECT_FALSE(metadata.headers().empty()); + EXPECT_EQ(metadata.headers().size(), 0); + metadata.headers().addCopy(Http::LowerCaseString("k"), "v"); + EXPECT_EQ(metadata.headers().size(), 1); } } // namespace ThriftProxy diff --git a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc index 621f826c46873..c8a4c260a4b98 100644 --- a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc @@ -45,7 +45,7 @@ class AutoProtocolTest : public testing::Test { EXPECT_FALSE(metadata_.hasFrameSize()); EXPECT_FALSE(metadata_.hasProtocol()); EXPECT_FALSE(metadata_.hasAppException()); - EXPECT_TRUE(metadata_.headers().empty()); + EXPECT_EQ(metadata_.headers().size(), 0); } void expectDefaultMetadata() { expectMetadata("-", MessageType::Oneway, -1); } diff --git a/test/extensions/filters/network/thrift_proxy/utility.h b/test/extensions/filters/network/thrift_proxy/utility.h index 21585f704c0b5..08953fe0cbf40 100644 --- a/test/extensions/filters/network/thrift_proxy/utility.h +++ b/test/extensions/filters/network/thrift_proxy/utility.h @@ -115,7 +115,7 @@ MATCHER(IsEmptyMetadata, "") { *result_listener << "has a message type of " << static_cast(arg.messageType()); return false; } - if (!arg.headers().empty()) { + if (arg.headers().size() > 0) { *result_listener << "has " << arg.headers().size() << " headers"; return false; } @@ -128,7 +128,7 @@ MATCHER(IsEmptyMetadata, "") { MATCHER_P(HasOnlyFrameSize, n, "") { return arg.hasFrameSize() && arg.frameSize() == n && !arg.hasProtocol() && !arg.hasMethodName() && - !arg.hasSequenceId() && !arg.hasMessageType() && arg.headers().empty() && + !arg.hasSequenceId() && !arg.hasMessageType() && arg.headers().size() == 0 && !arg.hasAppException(); } @@ -143,7 +143,7 @@ MATCHER_P(HasFrameSize, n, "") { MATCHER_P(HasProtocol, p, "") { return arg.hasProtocol() && arg.protocol() == p; } MATCHER_P(HasSequenceId, id, "") { return arg.hasSequenceId() && arg.sequenceId() == id; } -MATCHER(HasNoHeaders, "") { return arg.headers().empty(); } +MATCHER(HasNoHeaders, "") { return arg.headers().size() == 0; } MATCHER_P2(HasAppException, t, m, "") { if (!arg.hasAppException()) {