Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
35 commits
Select commit Hold shift + click to select a range
0f4025f
thrift: add support for case-sensitive header keys
Apr 19, 2022
75c55cc
Fix build
Apr 20, 2022
82654b8
Add missing method to mock
Apr 20, 2022
944cc90
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
Apr 20, 2022
d802775
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
Apr 20, 2022
7088873
Fix ASAN issue
Apr 21, 2022
26b27f3
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
Apr 21, 2022
64fb0d5
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
Apr 29, 2022
3cb7dfc
Use a formatter to preserve the case of http header keys.
Apr 29, 2022
5893146
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 2, 2022
3ab9fc4
Add TODO comment for moving code to common path
May 2, 2022
40fcc7e
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 10, 2022
c4d6567
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 10, 2022
be2c400
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 10, 2022
110fe1c
Preserve '\0', '\n', and '\r'
May 10, 2022
ef64016
Add changelog entry
May 10, 2022
c59921c
Link to Thrift spec
May 10, 2022
72075d2
Fix iter (set -> map)
May 10, 2022
b59a0a5
Missing header
May 10, 2022
2e7610e
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 10, 2022
dad295f
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 11, 2022
3f0fcdf
Add test for \n in header key
May 11, 2022
515abce
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 13, 2022
0158d54
@fishcakez's feedback.
May 13, 2022
0ef3783
Fix docs
May 16, 2022
f3cec5c
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 16, 2022
1f905d2
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 19, 2022
27ede7a
Update changelog comment
May 19, 2022
ece8b1e
encodeFrame: avoid header writer lambda duplicate
May 19, 2022
d37403d
Improve var naming
May 19, 2022
4a0c86c
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 23, 2022
d59db4b
setFormatter() only avail on impl types, not base classes
May 23, 2022
4ccddac
Fix
May 23, 2022
5823760
Merge remote-tracking branch 'upstream/main' into lowercase-thrift-he…
May 24, 2022
b518c2f
Fix
May 24, 2022
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ message Trds {
string route_config_name = 2;
}

// [#next-free-field: 10]
// [#next-free-field: 11]
message ThriftProxy {
option (udpa.annotations.versioning).previous_message_type =
"envoy.config.filter.network.thrift_proxy.v2alpha1.ThriftProxy";
Expand Down Expand Up @@ -117,6 +117,13 @@ message ThriftProxy {
// Configuration for :ref:`access logs <arch_overview_access_logs>`
// emitted by Thrift proxy.
repeated config.accesslog.v3.AccessLog access_log = 9;

// If set to true, Envoy will preserve the case of Thrift header keys instead of serializing them to
// lower case as per the default behavior. Note that NUL, CR and LF characters will also be preserved
// as mandated by the Thrift spec.
//
// More info: https://github.com/apache/thrift/commit/e165fa3c85d00cb984f4d9635ed60909a1266ce1.
bool header_keys_preserve_case = 10;
}

// ThriftFilter configures a Thrift filter.
Expand Down
3 changes: 3 additions & 0 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,9 @@ new_features:
- area: thrift
change: |
added support for access logging.
- area: thrift
change: |
added support for preserving header keys.
- area: thrift
change: |
introduced thrift configurable encoder and bidirectional filters, which allows peeking and modifying the thrift response message.
Expand Down
3 changes: 2 additions & 1 deletion source/extensions/filters/network/thrift_proxy/config.cc
Original file line number Diff line number Diff line change
Expand Up @@ -138,7 +138,8 @@ ConfigImpl::ConfigImpl(
stats_(ThriftFilterStats::generateStats(stats_prefix_, context_.scope())),
transport_(lookupTransport(config.transport())), proto_(lookupProtocol(config.protocol())),
payload_passthrough_(config.payload_passthrough()),
max_requests_per_connection_(config.max_requests_per_connection().value()) {
max_requests_per_connection_(config.max_requests_per_connection().value()),
header_keys_preserve_case_(config.header_keys_preserve_case()) {

if (config.thrift_filters().empty()) {
ENVOY_LOG(debug, "using default router filter");
Expand Down
2 changes: 2 additions & 0 deletions source/extensions/filters/network/thrift_proxy/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ class ConfigImpl : public Config,
const std::vector<AccessLog::InstanceSharedPtr>& accessLogs() const override {
return access_logs_;
}
bool headerKeysPreserveCase() const override { return header_keys_preserve_case_; }

private:
void processFilter(
Expand All @@ -108,6 +109,7 @@ class ConfigImpl : public Config,

const uint64_t max_requests_per_connection_{};
std::vector<AccessLog::InstanceSharedPtr> access_logs_;
const bool header_keys_preserve_case_;
};

} // namespace ThriftProxy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -220,6 +220,8 @@ bool ConnectionManager::passthroughEnabled() const {
return (*rpcs_.begin())->passthroughSupported();
}

bool ConnectionManager::headerKeysPreserveCase() const { return config_.headerKeysPreserveCase(); }

bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) {
upstream_buffer_.move(data);

Expand Down Expand Up @@ -440,6 +442,10 @@ FilterStatus ConnectionManager::ResponseDecoder::setEnd() {
return parent_.applyEncoderFilters(DecoderEvent::SetEnd, absl::any(), protocol_converter_);
}

bool ConnectionManager::ResponseDecoder::headerKeysPreserveCase() const {
return parent_.parent_.headerKeysPreserveCase();
}

void ConnectionManager::ActiveRpcDecoderFilter::continueDecoding() {
const FilterStatus status =
parent_.applyDecoderFilters(DecoderEvent::ContinueDecode, absl::any(), this);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ class Config {
virtual bool payloadPassthrough() const PURE;
virtual uint64_t maxRequestsPerConnection() const PURE;
virtual const std::vector<AccessLog::InstanceSharedPtr>& accessLogs() const PURE;
virtual bool headerKeysPreserveCase() const PURE;
};

/**
Expand Down Expand Up @@ -71,6 +72,7 @@ class ConnectionManager : public Network::ReadFilter,
DecoderEventHandler& newDecoderEventHandler() override;
bool passthroughEnabled() const override;
bool isRequest() const override { return true; }
bool headerKeysPreserveCase() const override;

private:
struct ActiveRpc;
Expand Down Expand Up @@ -116,6 +118,7 @@ class ConnectionManager : public Network::ReadFilter,
DecoderEventHandler& newDecoderEventHandler() override { return *this; }
bool passthroughEnabled() const override;
bool isRequest() const override { return false; }
bool headerKeysPreserveCase() const override;

void finalizeResponse();

Expand Down
3 changes: 2 additions & 1 deletion source/extensions/filters/network/thrift_proxy/decoder.cc
Original file line number Diff line number Diff line change
Expand Up @@ -436,7 +436,8 @@ FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer_underflow) {
if (!frame_started_) {
// Look for start of next frame.
if (!metadata_) {
metadata_ = std::make_shared<MessageMetadata>(callbacks_.isRequest());
metadata_ = std::make_shared<MessageMetadata>(callbacks_.isRequest(),
callbacks_.headerKeysPreserveCase());
}

if (!transport_.decodeFrameStart(data, *metadata_)) {
Expand Down
5 changes: 5 additions & 0 deletions source/extensions/filters/network/thrift_proxy/decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,11 @@ class DecoderCallbacks {
* See https://github.com/apache/thrift/blob/master/lib/ts/thrift.d.ts#L68.
*/
virtual bool isRequest() const PURE;

/**
* @return True if payload header keys should be treated as case-sensitive.
*/
virtual bool headerKeysPreserveCase() const PURE;
};

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
#include <limits>

#include "envoy/common/exception.h"
#include "envoy/http/header_formatter.h"

#include "source/common/buffer/buffer_impl.h"
#include "source/extensions/filters/network/thrift_proxy/buffer_helper.h"
Expand Down Expand Up @@ -132,6 +133,8 @@ bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMeta
}

const bool is_request = metadata.isRequest();
auto formatter =
is_request ? metadata.requestHeaders().formatter() : metadata.responseHeaders().formatter();

while (header_size > 0) {
// Attempt to read info blocks
Expand All @@ -150,9 +153,13 @@ bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMeta

while (num_headers-- > 0) {
std::string key_string = drainVarString(buffer, header_size, "header key");
if (formatter) {
formatter->processKey(key_string);
}
// LowerCaseString doesn't allow '\0', '\n', and '\r'.
key_string =
absl::StrReplaceAll(key_string, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}});

const Http::LowerCaseString key = Http::LowerCaseString(key_string);
const std::string value = drainVarString(buffer, header_size, "header value");

Expand Down Expand Up @@ -221,20 +228,22 @@ void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMet
// Num headers
BufferHelper::writeVarIntI32(header_buffer, static_cast<int32_t>(headers_size));

auto formatter = metadata.isRequest() ? metadata.requestHeaders().formatter()
: metadata.responseHeaders().formatter();

auto header_writer = [&header_buffer,
formatter](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
const auto header_key = header.key().getStringView();

writeVarString(header_buffer, formatter ? formatter->format(header_key) : header_key);
writeVarString(header_buffer, header.value().getStringView());
return Http::HeaderMap::Iterate::Continue;
};

if (metadata.isRequest()) {
metadata.requestHeaders().iterate(
[&header_buffer](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
writeVarString(header_buffer, header.key().getStringView());
writeVarString(header_buffer, header.value().getStringView());
return Http::HeaderMap::Iterate::Continue;
});
metadata.requestHeaders().iterate(header_writer);
} else {
metadata.responseHeaders().iterate(
[&header_buffer](const Http::HeaderEntry& header) -> Http::HeaderMap::Iterate {
writeVarString(header_buffer, header.key().getStringView());
writeVarString(header_buffer, header.value().getStringView());
return Http::HeaderMap::Iterate::Continue;
});
metadata.responseHeaders().iterate(header_writer);
}
}

Expand Down
42 changes: 39 additions & 3 deletions source/extensions/filters/network/thrift_proxy/metadata.h
Original file line number Diff line number Diff line change
Expand Up @@ -7,18 +7,46 @@
#include <string>

#include "envoy/buffer/buffer.h"
#include "envoy/http/header_formatter.h"

#include "source/common/common/macros.h"
#include "source/common/http/header_map_impl.h"
#include "source/extensions/filters/network/thrift_proxy/thrift.h"
#include "source/extensions/filters/network/thrift_proxy/tracing.h"

#include "absl/strings/str_replace.h"
#include "absl/types/optional.h"

namespace Envoy {
namespace Extensions {
namespace NetworkFilters {
namespace ThriftProxy {
namespace {

// See: https://github.com/apache/thrift/commit/e165fa3c85d00cb984f4d9635ed60909a1266ce1
class ThriftCaseHeaderFormatter : public Envoy::Http::StatefulHeaderKeyFormatter {
public:
ThriftCaseHeaderFormatter() = default;

// Envoy::Http::StatefulHeaderKeyFormatter
std::string format(absl::string_view key) const override {
const auto remembered_key_itr = original_header_keys_.find(key);
return remembered_key_itr != original_header_keys_.end() ? remembered_key_itr->second
: std::string(key);
}
void processKey(absl::string_view key) override {
std::string s = absl::StrReplaceAll(key, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}});
std::transform(s.begin(), s.end(), s.begin(), [](unsigned char c) { return std::tolower(c); });
original_header_keys_.try_emplace(std::move(s), std::string(key));
}
void setReasonPhrase(absl::string_view) override {}
absl::string_view getReasonPhrase() const override { return ""; }

private:
absl::flat_hash_map<std::string, std::string> original_header_keys_;
};

} // namespace

/**
* MessageMetadata encapsulates metadata about Thrift messages. The various fields are considered
Expand All @@ -28,11 +56,19 @@ namespace ThriftProxy {
*/
class MessageMetadata {
public:
MessageMetadata(bool is_request = true) : is_request_(is_request) {
MessageMetadata(bool is_request = true, bool preserve_keys = false) : is_request_(is_request) {
if (is_request) {
request_headers_ = Http::RequestHeaderMapImpl::create();
auto request_headers = Http::RequestHeaderMapImpl::create();
if (preserve_keys) {
request_headers->setFormatter(std::make_unique<ThriftCaseHeaderFormatter>());
}
request_headers_ = std::move(request_headers);
} else {
response_headers_ = Http::ResponseHeaderMapImpl::create();
auto response_headers = Http::ResponseHeaderMapImpl::create();
if (preserve_keys) {
response_headers->setFormatter(std::make_unique<ThriftCaseHeaderFormatter>());
}
response_headers_ = std::move(response_headers);
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,7 @@ struct NullResponseDecoder : public DecoderCallbacks, public ProtocolConverter {
DecoderEventHandler& newDecoderEventHandler() override { return *this; }
bool passthroughEnabled() const override { return true; }
bool isRequest() const override { return false; }
bool headerKeysPreserveCase() const override { return false; }

DecoderPtr decoder_;
Buffer::OwnedImpl response_buffer_;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,6 +249,7 @@ class ThriftObjectImpl : public ThriftObject,
}
bool passthroughEnabled() const override { return false; }
bool isRequest() const override { return false; }
bool headerKeysPreserveCase() const override { return false; }

// ThriftObject
bool onData(Buffer::Instance& buffer) override;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -436,10 +436,11 @@ TEST(HeaderTransportTest, InvalidInfoBlock) {
}
}

TEST(HeaderTransportTest, InfoBlock) {
MessageMetadata testInfoBlock(bool preserve_keys, const std::string& key,
const std::string& value) {
HeaderTransportImpl transport;
Buffer::OwnedImpl buffer;
MessageMetadata metadata(true);
MessageMetadata metadata(true, preserve_keys);

metadata.requestHeaders().addCopy(Http::LowerCaseString("not"), "empty");

Expand All @@ -449,10 +450,10 @@ TEST(HeaderTransportTest, InfoBlock) {
buffer.writeBEInt<int32_t>(1); // sequence number
buffer.writeBEInt<int16_t>(38); // size 152
addSeq(buffer, {0, 0, 1, 3}); // 0 = binary proto, 0 = num transforms, 1 = key value, 3 = num kvs
buffer.writeByte(3);
buffer.add("key");
buffer.writeByte(5);
buffer.add("value");
buffer.writeByte(key.size());
buffer.add(key);
buffer.writeByte(value.size());
buffer.add(value);
buffer.writeByte(4);
buffer.add("key2");
addSeq(buffer, {0x80, 0x01}); // var int 128
Expand All @@ -463,7 +464,9 @@ TEST(HeaderTransportTest, InfoBlock) {

Http::TestRequestHeaderMapImpl expected_headers;
expected_headers.addCopy(Http::LowerCaseString("not"), "empty");
expected_headers.addCopy(Http::LowerCaseString("key"), "value");
expected_headers.addCopy(Http::LowerCaseString(absl::StrReplaceAll(
key, {{std::string(1, '\0'), ""}, {"\n", ""}, {"\r", ""}})),
value);
expected_headers.addCopy(Http::LowerCaseString("key2"), std::string(128, 'x'));
expected_headers.addCopy(Http::LowerCaseString(""), "");

Expand All @@ -472,6 +475,43 @@ TEST(HeaderTransportTest, InfoBlock) {

EXPECT_EQ(expected_headers, metadata.requestHeaders());
EXPECT_EQ(buffer.length(), 0);

return metadata;
}

TEST(HeaderTransportTest, InfoBlock) { testInfoBlock(false /* preserve-keys */, "key", "value"); }

TEST(HeaderTransportTest, InfoBlockCaseSensitive) {
auto metadata = testInfoBlock(true /* preserve-keys */, "Key", "Value");
HeaderTransportImpl transport;
Buffer::OwnedImpl buffer;
Buffer::OwnedImpl msg;
msg.add("fake message");
transport.encodeFrame(buffer, metadata, msg);
EXPECT_EQ(0, msg.length());
EXPECT_EQ(std::string("\0\0\0\xBA\xF\xFF\0\0\0\0\0\x1\0)\0\0\x1\x4\x3not\x5"
"empty\x3Key\x5Value\x4key2\x80\x1xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
"xxxxxxxxxxxxx\0\0\0\0\0fake message",
190),
buffer.toString());
}

TEST(HeaderTransportTest, InfoBlockCaseSensitiveNewline) {
auto metadata = testInfoBlock(true /* preserve-keys */, "K\ny", "Value");
HeaderTransportImpl transport;
Buffer::OwnedImpl buffer;
Buffer::OwnedImpl msg;
msg.add("fake message");
transport.encodeFrame(buffer, metadata, msg);
EXPECT_EQ(0, msg.length());
EXPECT_EQ(
std::string("\0\0\0\xBA\xF\xFF\0\0\0\0\0\x1\0)\0\0\x1\x4\x3not\x5"
"empty\x3K\ny\x5Value\x4key2\x80\x1xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
"xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx"
"xxxxxxxxxxxxx\0\0\0\0\0fake message",
190),
buffer.toString());
}

TEST(HeaderTransportTest, DecodeFrameEnd) {
Expand Down
1 change: 1 addition & 0 deletions test/extensions/filters/network/thrift_proxy/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ class MockDecoderCallbacks : public DecoderCallbacks {
MOCK_METHOD(DecoderEventHandler&, newDecoderEventHandler, ());
MOCK_METHOD(bool, passthroughEnabled, (), (const));
MOCK_METHOD(bool, isRequest, (), (const));
MOCK_METHOD(bool, headerKeysPreserveCase, (), (const));
};

class MockDecoderEventHandler : public DecoderEventHandler {
Expand Down