diff --git a/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto b/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto index abf146643491a..d06d7c2aedff0 100644 --- a/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto +++ b/api/envoy/extensions/filters/network/thrift_proxy/v3/thrift_proxy.proto @@ -74,7 +74,7 @@ message Trds { string route_config_name = 2; } -// [#next-free-field: 10] +// [#next-free-field: 11] message ThriftProxy { option (udpa.annotations.versioning).previous_message_type = "envoy.config.filter.network.thrift_proxy.v2alpha1.ThriftProxy"; @@ -117,6 +117,13 @@ message ThriftProxy { // Configuration for :ref:`access logs ` // emitted by Thrift proxy. repeated config.accesslog.v3.AccessLog access_log = 9; + + // If set to true, Envoy will preserve the case of Thrift header keys instead of serializing them to + // lower case as per the default behavior. Note that NUL, CR and LF characters will also be preserved + // as mandated by the Thrift spec. + // + // More info: https://github.com/apache/thrift/commit/e165fa3c85d00cb984f4d9635ed60909a1266ce1. + bool header_keys_preserve_case = 10; } // ThriftFilter configures a Thrift filter. diff --git a/changelogs/current.yaml b/changelogs/current.yaml index 2d352ec3cdbc6..832c9cbbbc680 100644 --- a/changelogs/current.yaml +++ b/changelogs/current.yaml @@ -138,6 +138,9 @@ new_features: - area: thrift change: | added support for access logging. +- area: thrift + change: | + added support for preserving header keys. - area: thrift change: | introduced thrift configurable encoder and bidirectional filters, which allows peeking and modifying the thrift response message. diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index 2bd378213018f..9e39011645299 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -138,7 +138,8 @@ ConfigImpl::ConfigImpl( stats_(ThriftFilterStats::generateStats(stats_prefix_, context_.scope())), transport_(lookupTransport(config.transport())), proto_(lookupProtocol(config.protocol())), payload_passthrough_(config.payload_passthrough()), - max_requests_per_connection_(config.max_requests_per_connection().value()) { + max_requests_per_connection_(config.max_requests_per_connection().value()), + header_keys_preserve_case_(config.header_keys_preserve_case()) { if (config.thrift_filters().empty()) { ENVOY_LOG(debug, "using default router filter"); diff --git a/source/extensions/filters/network/thrift_proxy/config.h b/source/extensions/filters/network/thrift_proxy/config.h index 4cdf68523c119..4f37c6777f1c5 100644 --- a/source/extensions/filters/network/thrift_proxy/config.h +++ b/source/extensions/filters/network/thrift_proxy/config.h @@ -91,6 +91,7 @@ class ConfigImpl : public Config, const std::vector& accessLogs() const override { return access_logs_; } + bool headerKeysPreserveCase() const override { return header_keys_preserve_case_; } private: void processFilter( @@ -108,6 +109,7 @@ class ConfigImpl : public Config, const uint64_t max_requests_per_connection_{}; std::vector access_logs_; + const bool header_keys_preserve_case_; }; } // namespace ThriftProxy diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index dfcfe9a0ac6b4..b6f079f32b0dc 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -220,6 +220,8 @@ bool ConnectionManager::passthroughEnabled() const { return (*rpcs_.begin())->passthroughSupported(); } +bool ConnectionManager::headerKeysPreserveCase() const { return config_.headerKeysPreserveCase(); } + bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { upstream_buffer_.move(data); @@ -440,6 +442,10 @@ FilterStatus ConnectionManager::ResponseDecoder::setEnd() { return parent_.applyEncoderFilters(DecoderEvent::SetEnd, absl::any(), protocol_converter_); } +bool ConnectionManager::ResponseDecoder::headerKeysPreserveCase() const { + return parent_.parent_.headerKeysPreserveCase(); +} + void ConnectionManager::ActiveRpcDecoderFilter::continueDecoding() { const FilterStatus status = parent_.applyDecoderFilters(DecoderEvent::ContinueDecode, absl::any(), this); diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index f70e71d0707f7..fe4c5b3a1c177 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -43,6 +43,7 @@ class Config { virtual bool payloadPassthrough() const PURE; virtual uint64_t maxRequestsPerConnection() const PURE; virtual const std::vector& accessLogs() const PURE; + virtual bool headerKeysPreserveCase() const PURE; }; /** @@ -71,6 +72,7 @@ class ConnectionManager : public Network::ReadFilter, DecoderEventHandler& newDecoderEventHandler() override; bool passthroughEnabled() const override; bool isRequest() const override { return true; } + bool headerKeysPreserveCase() const override; private: struct ActiveRpc; @@ -116,6 +118,7 @@ class ConnectionManager : public Network::ReadFilter, DecoderEventHandler& newDecoderEventHandler() override { return *this; } bool passthroughEnabled() const override; bool isRequest() const override { return false; } + bool headerKeysPreserveCase() const override; void finalizeResponse(); diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index bcbb8a31997c4..17b79c0a5d712 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -436,7 +436,8 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) { if (!frame_started_) { // Look for start of next frame. if (!metadata_) { - metadata_ = std::make_shared(callbacks_.isRequest()); + metadata_ = std::make_shared(callbacks_.isRequest(), + callbacks_.headerKeysPreserveCase()); } if (!transport_.decodeFrameStart(data, *metadata_)) { diff --git a/source/extensions/filters/network/thrift_proxy/decoder.h b/source/extensions/filters/network/thrift_proxy/decoder.h index a1736010f1555..0508464a07b40 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.h +++ b/source/extensions/filters/network/thrift_proxy/decoder.h @@ -203,6 +203,11 @@ class DecoderCallbacks { * See https://github.com/apache/thrift/blob/master/lib/ts/thrift.d.ts#L68. */ virtual bool isRequest() const PURE; + + /** + * @return True if payload header keys should be treated as case-sensitive. + */ + virtual bool headerKeysPreserveCase() const PURE; }; /** diff --git a/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc index c10cf2dd21f62..4a5155f4dd012 100644 --- a/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc @@ -3,6 +3,7 @@ #include #include "envoy/common/exception.h" +#include "envoy/http/header_formatter.h" #include "source/common/buffer/buffer_impl.h" #include "source/extensions/filters/network/thrift_proxy/buffer_helper.h" @@ -132,6 +133,8 @@ bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMeta } const bool is_request = metadata.isRequest(); + auto formatter = + is_request ? metadata.requestHeaders().formatter() : metadata.responseHeaders().formatter(); while (header_size > 0) { // Attempt to read info blocks @@ -150,9 +153,13 @@ bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMeta while (num_headers-- > 0) { std::string key_string = drainVarString(buffer, header_size, "header key"); + if (formatter) { + formatter->processKey(key_string); + } // LowerCaseString doesn't allow '\0', '\n', and '\r'. key_string = absl::StrReplaceAll(key_string, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}}); + const Http::LowerCaseString key = Http::LowerCaseString(key_string); const std::string value = drainVarString(buffer, header_size, "header value"); @@ -221,20 +228,22 @@ void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMet // Num headers BufferHelper::writeVarIntI32(header_buffer, static_cast(headers_size)); + auto formatter = metadata.isRequest() ? metadata.requestHeaders().formatter() + : metadata.responseHeaders().formatter(); + + auto header_writer = [&header_buffer, + formatter](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { + const auto header_key = header.key().getStringView(); + + writeVarString(header_buffer, formatter ? formatter->format(header_key) : header_key); + writeVarString(header_buffer, header.value().getStringView()); + return Http::HeaderMap::Iterate::Continue; + }; + if (metadata.isRequest()) { - metadata.requestHeaders().iterate( - [&header_buffer](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { - writeVarString(header_buffer, header.key().getStringView()); - writeVarString(header_buffer, header.value().getStringView()); - return Http::HeaderMap::Iterate::Continue; - }); + metadata.requestHeaders().iterate(header_writer); } else { - metadata.responseHeaders().iterate( - [&header_buffer](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate { - writeVarString(header_buffer, header.key().getStringView()); - writeVarString(header_buffer, header.value().getStringView()); - return Http::HeaderMap::Iterate::Continue; - }); + metadata.responseHeaders().iterate(header_writer); } } diff --git a/source/extensions/filters/network/thrift_proxy/metadata.h b/source/extensions/filters/network/thrift_proxy/metadata.h index cc013ed30dafa..001fd713fb661 100644 --- a/source/extensions/filters/network/thrift_proxy/metadata.h +++ b/source/extensions/filters/network/thrift_proxy/metadata.h @@ -7,18 +7,46 @@ #include #include "envoy/buffer/buffer.h" +#include "envoy/http/header_formatter.h" #include "source/common/common/macros.h" #include "source/common/http/header_map_impl.h" #include "source/extensions/filters/network/thrift_proxy/thrift.h" #include "source/extensions/filters/network/thrift_proxy/tracing.h" +#include "absl/strings/str_replace.h" #include "absl/types/optional.h" namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +namespace { + +// See: https://github.com/apache/thrift/commit/e165fa3c85d00cb984f4d9635ed60909a1266ce1 +class ThriftCaseHeaderFormatter : public Envoy::Http::StatefulHeaderKeyFormatter { +public: + ThriftCaseHeaderFormatter() = default; + + // Envoy::Http::StatefulHeaderKeyFormatter + std::string format(absl::string_view key) const override { + const auto remembered_key_itr = original_header_keys_.find(key); + return remembered_key_itr != original_header_keys_.end() ? remembered_key_itr->second + : std::string(key); + } + void processKey(absl::string_view key) override { + std::string s = absl::StrReplaceAll(key, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}}); + std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); }); + original_header_keys_.try_emplace(std::move(s), std::string(key)); + } + void setReasonPhrase(absl::string_view) override {} + absl::string_view getReasonPhrase() const override { return ""; } + +private: + absl::flat_hash_map original_header_keys_; +}; + +} // namespace /** * MessageMetadata encapsulates metadata about Thrift messages. The various fields are considered @@ -28,11 +56,19 @@ namespace ThriftProxy { */ class MessageMetadata { public: - MessageMetadata(bool is_request = true) : is_request_(is_request) { + MessageMetadata(bool is_request = true, bool preserve_keys = false) : is_request_(is_request) { if (is_request) { - request_headers_ = Http::RequestHeaderMapImpl::create(); + auto request_headers = Http::RequestHeaderMapImpl::create(); + if (preserve_keys) { + request_headers->setFormatter(std::make_unique()); + } + request_headers_ = std::move(request_headers); } else { - response_headers_ = Http::ResponseHeaderMapImpl::create(); + auto response_headers = Http::ResponseHeaderMapImpl::create(); + if (preserve_keys) { + response_headers->setFormatter(std::make_unique()); + } + response_headers_ = std::move(response_headers); } } diff --git a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h index dc241fb624890..bc02d856cab7b 100644 --- a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h @@ -76,6 +76,7 @@ struct NullResponseDecoder : public DecoderCallbacks, public ProtocolConverter { DecoderEventHandler& newDecoderEventHandler() override { return *this; } bool passthroughEnabled() const override { return true; } bool isRequest() const override { return false; } + bool headerKeysPreserveCase() const override { return false; } DecoderPtr decoder_; Buffer::OwnedImpl response_buffer_; diff --git a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h index 24c43c85565ea..b1bd5b8c949e7 100644 --- a/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h +++ b/source/extensions/filters/network/thrift_proxy/thrift_object_impl.h @@ -249,6 +249,7 @@ class ThriftObjectImpl : public ThriftObject, } bool passthroughEnabled() const override { return false; } bool isRequest() const override { return false; } + bool headerKeysPreserveCase() const override { return false; } // ThriftObject bool onData(Buffer::Instance& buffer) override; diff --git a/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc index 0c08e4c4a1c9b..3a7bcebd7ce87 100644 --- a/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc @@ -436,10 +436,11 @@ TEST(HeaderTransportTest, InvalidInfoBlock) { } } -TEST(HeaderTransportTest, InfoBlock) { +MessageMetadata testInfoBlock(bool preserve_keys, const std::string& key, + const std::string& value) { HeaderTransportImpl transport; Buffer::OwnedImpl buffer; - MessageMetadata metadata(true); + MessageMetadata metadata(true, preserve_keys); metadata.requestHeaders().addCopy(Http::LowerCaseString("not"), "empty"); @@ -449,10 +450,10 @@ TEST(HeaderTransportTest, InfoBlock) { buffer.writeBEInt(1); // sequence number buffer.writeBEInt(38); // size 152 addSeq(buffer, {0, 0, 1, 3}); // 0 = binary proto, 0 = num transforms, 1 = key value, 3 = num kvs - buffer.writeByte(3); - buffer.add("key"); - buffer.writeByte(5); - buffer.add("value"); + buffer.writeByte(key.size()); + buffer.add(key); + buffer.writeByte(value.size()); + buffer.add(value); buffer.writeByte(4); buffer.add("key2"); addSeq(buffer, {0x80, 0x01}); // var int 128 @@ -463,7 +464,9 @@ TEST(HeaderTransportTest, InfoBlock) { Http::TestRequestHeaderMapImpl expected_headers; expected_headers.addCopy(Http::LowerCaseString("not"), "empty"); - expected_headers.addCopy(Http::LowerCaseString("key"), "value"); + expected_headers.addCopy(Http::LowerCaseString(absl::StrReplaceAll( + key, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}})), + value); expected_headers.addCopy(Http::LowerCaseString("key2"), std::string(128, 'x')); expected_headers.addCopy(Http::LowerCaseString(""), ""); @@ -472,6 +475,43 @@ TEST(HeaderTransportTest, InfoBlock) { EXPECT_EQ(expected_headers, metadata.requestHeaders()); EXPECT_EQ(buffer.length(), 0); + + return metadata; +} + +TEST(HeaderTransportTest, InfoBlock) { testInfoBlock(false /* preserve-keys */, "key", "value"); } + +TEST(HeaderTransportTest, InfoBlockCaseSensitive) { + auto metadata = testInfoBlock(true /* preserve-keys */, "Key", "Value"); + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + Buffer::OwnedImpl msg; + msg.add("fake message"); + transport.encodeFrame(buffer, metadata, msg); + EXPECT_EQ(0, msg.length()); + EXPECT_EQ(std::string("\0\0\0\xBA\xF\xFF\0\0\0\0\0\x1\0)\0\0\x1\x4\x3not\x5" + "empty\x3Key\x5Value\x4key2\x80\x1xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + "xxxxxxxxxxxxx\0\0\0\0\0fake message", + 190), + buffer.toString()); +} + +TEST(HeaderTransportTest, InfoBlockCaseSensitiveNewline) { + auto metadata = testInfoBlock(true /* preserve-keys */, "K\ny", "Value"); + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + Buffer::OwnedImpl msg; + msg.add("fake message"); + transport.encodeFrame(buffer, metadata, msg); + EXPECT_EQ(0, msg.length()); + EXPECT_EQ( + std::string("\0\0\0\xBA\xF\xFF\0\0\0\0\0\x1\0)\0\0\x1\x4\x3not\x5" + "empty\x3K\ny\x5Value\x4key2\x80\x1xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + "xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx" + "xxxxxxxxxxxxx\0\0\0\0\0fake message", + 190), + buffer.toString()); } TEST(HeaderTransportTest, DecodeFrameEnd) { diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index c5a6fd393b943..3164ef6e2f227 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -137,6 +137,7 @@ class MockDecoderCallbacks : public DecoderCallbacks { MOCK_METHOD(DecoderEventHandler&, newDecoderEventHandler, ()); MOCK_METHOD(bool, passthroughEnabled, (), (const)); MOCK_METHOD(bool, isRequest, (), (const)); + MOCK_METHOD(bool, headerKeysPreserveCase, (), (const)); }; class MockDecoderEventHandler : public DecoderEventHandler {