Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions changelogs/current.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,9 @@ new_features:
- area: thrift
change: |
added support for preserving header keys.
- area: thrift
change: |
added support onLocalReply to inform filters of local replies.
- area: thrift
change: |
introduced thrift configurable encoder and bidirectional filters, which allows peeking and modifying the thrift response message.
Expand Down
13 changes: 13 additions & 0 deletions source/extensions/filters/network/thrift_proxy/conn_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -123,6 +123,7 @@ void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectRe

read_callbacks_->connection().write(response_buffer, end_stream);
}

if (end_stream) {
read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite);
}
Expand Down Expand Up @@ -947,8 +948,20 @@ Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() {
return cached_route_.value();
}

void ConnectionManager::ActiveRpc::onLocalReply(const MessageMetadata& metadata, bool end_stream) {
under_on_local_reply_ = true;
for (auto& filter : base_filters_) {
filter->onLocalReply(metadata, end_stream);
}
under_on_local_reply_ = false;
}

void ConnectionManager::ActiveRpc::sendLocalReply(const DirectResponse& response, bool end_stream) {
ASSERT(!under_on_local_reply_);
metadata_->setSequenceId(original_sequence_id_);

onLocalReply(*metadata_, end_stream);

parent_.sendLocalReply(*metadata_, response, end_stream);

if (end_stream) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,8 @@ class ConnectionManager : public Network::ReadFilter,
stream_id_(parent_.random_generator_.random()),
stream_info_(parent_.time_source_,
parent_.read_callbacks_->connection().connectionInfoProviderSharedPtr()),
local_response_sent_{false}, pending_transport_end_{false}, passthrough_{false} {
local_response_sent_{false}, pending_transport_end_{false}, passthrough_{false},
under_on_local_reply_{false} {
parent_.stats_.request_active_.inc();
}
~ActiveRpc() override {
Expand Down Expand Up @@ -269,6 +270,7 @@ class ConnectionManager : public Network::ReadFilter,
ProtocolType downstreamProtocolType() const override {
return parent_.decoder_->protocolType();
}
void onLocalReply(const MessageMetadata& metadata, bool end_stream);
void sendLocalReply(const DirectResponse& response, bool end_stream) override;
void startUpstreamResponse(Transport& transport, Protocol& protocol) override;
ThriftFilters::ResponseStatus upstreamData(Buffer::Instance& buffer) override;
Expand Down Expand Up @@ -353,6 +355,7 @@ class ConnectionManager : public Network::ReadFilter,
bool local_response_sent_ : 1;
bool pending_transport_end_ : 1;
bool passthrough_ : 1;
bool under_on_local_reply_ : 1;
};

using ActiveRpcPtr = std::unique_ptr<ActiveRpc>;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,10 @@ class BidirectionalFilterWrapper final : public FilterBase {
// ThriftBaseFilter
void onDestroy() override { parent_->onDestroy(); }

LocalErrorStatus onLocalReply(const MessageMetadata& metadata, bool reset_imminent) override {
return parent_->onLocalReply(metadata, reset_imminent);
}

DecoderFilterSharedPtr decoder_filter_;
EncoderFilterSharedPtr encoder_filter_;

Expand Down
27 changes: 27 additions & 0 deletions source/extensions/filters/network/thrift_proxy/filters/filter.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,12 +142,39 @@ class EncoderFilterCallbacks : public virtual FilterCallbacks {
virtual void continueEncoding() PURE;
};

/**
* Return codes for onLocalReply filter invocations.
*/
enum class LocalErrorStatus {
// Continue sending the local reply after onLocalError has been sent to all filters.
Continue,
};

/**
* Common interface for Thrift filters.
*/
class FilterBase {
public:
virtual ~FilterBase() = default;

/**
* Called after sendLocalReply is called, and before any local reply is
* serialized either to filters, or downstream.
* This will be called on both encoder and decoder filters starting at the
* first filter and working towards the terminal filter configured (generally the router filter).
*
* Filters implementing onLocalReply are responsible for never calling sendLocalReply
* from onLocalReply, as that has the potential for looping.
*
* @param metadata response metadata.
* @param reset_imminent True if the downstream connection should be closed after this response
* @param LocalErrorStatus the action to take after onLocalError completes.
*/
virtual LocalErrorStatus onLocalReply([[maybe_unused]] const MessageMetadata& metadata,
[[maybe_unused]] bool end_stream) {
return LocalErrorStatus::Continue;
}

/**
* This routine is called prior to a filter being destroyed. This may happen after normal stream
* finish (both downstream and upstream) or due to reset. Every filter is responsible for making
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1833,6 +1833,14 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalReply) {
buffer.add("response");
return DirectResponse::ResponseType::SuccessReply;
}));
{
InSequence s;
EXPECT_CALL(*custom_decoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*decoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*custom_encoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*encoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*bidirectional_filter_, onLocalReply(_, _));
}

// First filter sends local reply.
EXPECT_CALL(*custom_decoder_filter_, messageBegin(_))
Expand Down Expand Up @@ -1879,6 +1887,15 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalErrorReply) {
return DirectResponse::ResponseType::ErrorReply;
}));

{
InSequence s;
EXPECT_CALL(*custom_decoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*decoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*custom_encoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*encoder_filter_, onLocalReply(_, _));
EXPECT_CALL(*bidirectional_filter_, onLocalReply(_, _));
}

// First filter sends local reply.
EXPECT_CALL(*custom_decoder_filter_, messageBegin(_))
.WillOnce(Invoke([&](MessageMetadataSharedPtr) -> FilterStatus {
Expand Down
6 changes: 6 additions & 0 deletions test/extensions/filters/network/thrift_proxy/mocks.h
Original file line number Diff line number Diff line change
Expand Up @@ -214,6 +214,8 @@ class MockDecoderFilter : public DecoderFilter {

// ThriftProxy::ThriftFilters::DecoderFilter
MOCK_METHOD(void, onDestroy, ());
MOCK_METHOD(ThriftFilters::LocalErrorStatus, onLocalReply,
(const MessageMetadata& metadata, bool reset_imminent));
MOCK_METHOD(void, setDecoderFilterCallbacks, (DecoderFilterCallbacks & callbacks));
MOCK_METHOD(bool, passthroughSupported, (), (const));

Expand Down Expand Up @@ -280,6 +282,8 @@ class MockEncoderFilter : public EncoderFilter {

// ThriftProxy::ThriftFilters::EncoderFilter
MOCK_METHOD(void, onDestroy, ());
MOCK_METHOD(ThriftFilters::LocalErrorStatus, onLocalReply,
(const MessageMetadata& metadata, bool reset_imminent));
MOCK_METHOD(void, setEncoderFilterCallbacks, (EncoderFilterCallbacks & callbacks));
MOCK_METHOD(bool, passthroughSupported, (), (const));

Expand Down Expand Up @@ -343,6 +347,8 @@ class MockBidirectionalFilter : public BidirectionalFilter {

// ThriftProxy::ThriftFilters::BidirectionalFilter
MOCK_METHOD(void, onDestroy, ());
MOCK_METHOD(ThriftFilters::LocalErrorStatus, onLocalReply,
(const MessageMetadata& metadata, bool reset_imminent));
MOCK_METHOD(void, setEncoderFilterCallbacks, (EncoderFilterCallbacks & callbacks));
MOCK_METHOD(bool, encodePassthroughSupported, (), (const));
MOCK_METHOD(void, setDecoderFilterCallbacks, (DecoderFilterCallbacks & callbacks));
Expand Down