diff --git a/changelogs/current.yaml b/changelogs/current.yaml index 84a79904ec457..2b0f7fa9893c1 100644 --- a/changelogs/current.yaml +++ b/changelogs/current.yaml @@ -147,6 +147,9 @@ new_features: - area: thrift change: | added support for preserving header keys. +- area: thrift + change: | + added support onLocalReply to inform filters of local replies. - area: thrift change: | introduced thrift configurable encoder and bidirectional filters, which allows peeking and modifying the thrift response message. diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index b6f079f32b0dc..5602a77bf91dd 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -123,6 +123,7 @@ void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectRe read_callbacks_->connection().write(response_buffer, end_stream); } + if (end_stream) { read_callbacks_->connection().close(Network::ConnectionCloseType::FlushWrite); } @@ -947,8 +948,20 @@ Router::RouteConstSharedPtr ConnectionManager::ActiveRpc::route() { return cached_route_.value(); } +void ConnectionManager::ActiveRpc::onLocalReply(const MessageMetadata& metadata, bool end_stream) { + under_on_local_reply_ = true; + for (auto& filter : base_filters_) { + filter->onLocalReply(metadata, end_stream); + } + under_on_local_reply_ = false; +} + void ConnectionManager::ActiveRpc::sendLocalReply(const DirectResponse& response, bool end_stream) { + ASSERT(!under_on_local_reply_); metadata_->setSequenceId(original_sequence_id_); + + onLocalReply(*metadata_, end_stream); + parent_.sendLocalReply(*metadata_, response, end_stream); if (end_stream) { diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index fe4c5b3a1c177..e0d98f1528005 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -210,7 +210,8 @@ class ConnectionManager : public Network::ReadFilter, stream_id_(parent_.random_generator_.random()), stream_info_(parent_.time_source_, parent_.read_callbacks_->connection().connectionInfoProviderSharedPtr()), - local_response_sent_{false}, pending_transport_end_{false}, passthrough_{false} { + local_response_sent_{false}, pending_transport_end_{false}, passthrough_{false}, + under_on_local_reply_{false} { parent_.stats_.request_active_.inc(); } ~ActiveRpc() override { @@ -269,6 +270,7 @@ class ConnectionManager : public Network::ReadFilter, ProtocolType downstreamProtocolType() const override { return parent_.decoder_->protocolType(); } + void onLocalReply(const MessageMetadata& metadata, bool end_stream); void sendLocalReply(const DirectResponse& response, bool end_stream) override; void startUpstreamResponse(Transport& transport, Protocol& protocol) override; ThriftFilters::ResponseStatus upstreamData(Buffer::Instance& buffer) override; @@ -353,6 +355,7 @@ class ConnectionManager : public Network::ReadFilter, bool local_response_sent_ : 1; bool pending_transport_end_ : 1; bool passthrough_ : 1; + bool under_on_local_reply_ : 1; }; using ActiveRpcPtr = std::unique_ptr; diff --git a/source/extensions/filters/network/thrift_proxy/filter_utils.h b/source/extensions/filters/network/thrift_proxy/filter_utils.h index b4e0c315c3366..cfe7290bce2c2 100644 --- a/source/extensions/filters/network/thrift_proxy/filter_utils.h +++ b/source/extensions/filters/network/thrift_proxy/filter_utils.h @@ -17,6 +17,10 @@ class BidirectionalFilterWrapper final : public FilterBase { // ThriftBaseFilter void onDestroy() override { parent_->onDestroy(); } + LocalErrorStatus onLocalReply(const MessageMetadata& metadata, bool reset_imminent) override { + return parent_->onLocalReply(metadata, reset_imminent); + } + DecoderFilterSharedPtr decoder_filter_; EncoderFilterSharedPtr encoder_filter_; diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index a4a8d57be6763..200ce2b6bcd23 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -142,12 +142,39 @@ class EncoderFilterCallbacks : public virtual FilterCallbacks { virtual void continueEncoding() PURE; }; +/** + * Return codes for onLocalReply filter invocations. + */ +enum class LocalErrorStatus { + // Continue sending the local reply after onLocalError has been sent to all filters. + Continue, +}; + /** * Common interface for Thrift filters. */ class FilterBase { public: virtual ~FilterBase() = default; + + /** + * Called after sendLocalReply is called, and before any local reply is + * serialized either to filters, or downstream. + * This will be called on both encoder and decoder filters starting at the + * first filter and working towards the terminal filter configured (generally the router filter). + * + * Filters implementing onLocalReply are responsible for never calling sendLocalReply + * from onLocalReply, as that has the potential for looping. + * + * @param metadata response metadata. + * @param reset_imminent True if the downstream connection should be closed after this response + * @param LocalErrorStatus the action to take after onLocalError completes. + */ + virtual LocalErrorStatus onLocalReply([[maybe_unused]] const MessageMetadata& metadata, + [[maybe_unused]] bool end_stream) { + return LocalErrorStatus::Continue; + } + /** * This routine is called prior to a filter being destroyed. This may happen after normal stream * finish (both downstream and upstream) or due to reset. Every filter is responsible for making diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 3e6b0b645396a..b1a15383be038 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -1833,6 +1833,14 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalReply) { buffer.add("response"); return DirectResponse::ResponseType::SuccessReply; })); + { + InSequence s; + EXPECT_CALL(*custom_decoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*decoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*custom_encoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*encoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*bidirectional_filter_, onLocalReply(_, _)); + } // First filter sends local reply. EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) @@ -1879,6 +1887,15 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalErrorReply) { return DirectResponse::ResponseType::ErrorReply; })); + { + InSequence s; + EXPECT_CALL(*custom_decoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*decoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*custom_encoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*encoder_filter_, onLocalReply(_, _)); + EXPECT_CALL(*bidirectional_filter_, onLocalReply(_, _)); + } + // First filter sends local reply. EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr) -> FilterStatus { diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 3164ef6e2f227..6d58bbb7aceb7 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -214,6 +214,8 @@ class MockDecoderFilter : public DecoderFilter { // ThriftProxy::ThriftFilters::DecoderFilter MOCK_METHOD(void, onDestroy, ()); + MOCK_METHOD(ThriftFilters::LocalErrorStatus, onLocalReply, + (const MessageMetadata& metadata, bool reset_imminent)); MOCK_METHOD(void, setDecoderFilterCallbacks, (DecoderFilterCallbacks & callbacks)); MOCK_METHOD(bool, passthroughSupported, (), (const)); @@ -280,6 +282,8 @@ class MockEncoderFilter : public EncoderFilter { // ThriftProxy::ThriftFilters::EncoderFilter MOCK_METHOD(void, onDestroy, ()); + MOCK_METHOD(ThriftFilters::LocalErrorStatus, onLocalReply, + (const MessageMetadata& metadata, bool reset_imminent)); MOCK_METHOD(void, setEncoderFilterCallbacks, (EncoderFilterCallbacks & callbacks)); MOCK_METHOD(bool, passthroughSupported, (), (const)); @@ -343,6 +347,8 @@ class MockBidirectionalFilter : public BidirectionalFilter { // ThriftProxy::ThriftFilters::BidirectionalFilter MOCK_METHOD(void, onDestroy, ()); + MOCK_METHOD(ThriftFilters::LocalErrorStatus, onLocalReply, + (const MessageMetadata& metadata, bool reset_imminent)); MOCK_METHOD(void, setEncoderFilterCallbacks, (EncoderFilterCallbacks & callbacks)); MOCK_METHOD(bool, encodePassthroughSupported, (), (const)); MOCK_METHOD(void, setDecoderFilterCallbacks, (DecoderFilterCallbacks & callbacks));