diff --git a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc index 3a77ebbf78097..abd51e1da50f7 100644 --- a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.cc @@ -79,14 +79,29 @@ void ShadowRouterImpl::flushPendingCallbacks() { pending_callbacks_.clear(); } +FilterStatus ShadowRouterImpl::runOrSave(std::function&& cb, + const std::function& on_save) { + if (requestStarted()) { + return cb(); + } + + pending_callbacks_.push_back(std::move(cb)); + + if (on_save) { + on_save(); + } + + return FilterStatus::Continue; +} + FilterStatus ShadowRouterImpl::passthroughData(Buffer::Instance& data) { if (requestStarted()) { return ProtocolConverter::passthroughData(data); } auto copied = std::make_shared(data); - auto cb = [copied = std::move(copied), this]() mutable { - ProtocolConverter::passthroughData(*copied); + auto cb = [copied = std::move(copied), this]() mutable -> FilterStatus { + return ProtocolConverter::passthroughData(*copied); }; pending_callbacks_.push_back(std::move(cb)); @@ -98,8 +113,8 @@ FilterStatus ShadowRouterImpl::structBegin(absl::string_view name) { return ProtocolConverter::structBegin(name); } - auto cb = [name_str = std::string(name), this]() { - ProtocolConverter::structBegin(absl::string_view(name_str)); + auto cb = [name_str = std::string(name), this]() -> FilterStatus { + return ProtocolConverter::structBegin(absl::string_view(name_str)); }; pending_callbacks_.push_back(std::move(cb)); @@ -107,14 +122,7 @@ FilterStatus ShadowRouterImpl::structBegin(absl::string_view name) { } FilterStatus ShadowRouterImpl::structEnd() { - if (requestStarted()) { - return ProtocolConverter::structEnd(); - } - - auto cb = [this]() { ProtocolConverter::structEnd(); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave([this]() -> FilterStatus { return ProtocolConverter::structEnd(); }); } FilterStatus ShadowRouterImpl::fieldBegin(absl::string_view name, FieldType& field_type, @@ -123,8 +131,8 @@ FilterStatus ShadowRouterImpl::fieldBegin(absl::string_view name, FieldType& fie return ProtocolConverter::fieldBegin(name, field_type, field_id); } - auto cb = [name_str = std::string(name), field_type, field_id, this]() mutable { - ProtocolConverter::fieldBegin(absl::string_view(name_str), field_type, field_id); + auto cb = [name_str = std::string(name), field_type, field_id, this]() mutable -> FilterStatus { + return ProtocolConverter::fieldBegin(absl::string_view(name_str), field_type, field_id); }; pending_callbacks_.push_back(std::move(cb)); @@ -132,80 +140,37 @@ FilterStatus ShadowRouterImpl::fieldBegin(absl::string_view name, FieldType& fie } FilterStatus ShadowRouterImpl::fieldEnd() { - if (requestStarted()) { - return ProtocolConverter::fieldEnd(); - } - - auto cb = [this]() { ProtocolConverter::fieldEnd(); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave([this]() -> FilterStatus { return ProtocolConverter::fieldEnd(); }); } FilterStatus ShadowRouterImpl::boolValue(bool& value) { - if (requestStarted()) { - return ProtocolConverter::boolValue(value); - } - - auto cb = [value, this]() mutable { ProtocolConverter::boolValue(value); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave( + [value, this]() mutable -> FilterStatus { return ProtocolConverter::boolValue(value); }); } FilterStatus ShadowRouterImpl::byteValue(uint8_t& value) { - if (requestStarted()) { - return ProtocolConverter::byteValue(value); - } - - auto cb = [value, this]() mutable { ProtocolConverter::byteValue(value); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave( + [value, this]() mutable -> FilterStatus { return ProtocolConverter::byteValue(value); }); } FilterStatus ShadowRouterImpl::int16Value(int16_t& value) { - if (requestStarted()) { - return ProtocolConverter::int16Value(value); - } - - auto cb = [value, this]() mutable { ProtocolConverter::int16Value(value); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave( + [value, this]() mutable -> FilterStatus { return ProtocolConverter::int16Value(value); }); } FilterStatus ShadowRouterImpl::int32Value(int32_t& value) { - if (requestStarted()) { - return ProtocolConverter::int32Value(value); - } - - auto cb = [value, this]() mutable { ProtocolConverter::int32Value(value); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave( + [value, this]() mutable -> FilterStatus { return ProtocolConverter::int32Value(value); }); } FilterStatus ShadowRouterImpl::int64Value(int64_t& value) { - if (requestStarted()) { - return ProtocolConverter::int64Value(value); - } - - auto cb = [value, this]() mutable { ProtocolConverter::int64Value(value); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave( + [value, this]() mutable -> FilterStatus { return ProtocolConverter::int64Value(value); }); } FilterStatus ShadowRouterImpl::doubleValue(double& value) { - if (requestStarted()) { - return ProtocolConverter::doubleValue(value); - } - - auto cb = [value, this]() mutable { ProtocolConverter::doubleValue(value); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave( + [value, this]() mutable -> FilterStatus { return ProtocolConverter::doubleValue(value); }); } FilterStatus ShadowRouterImpl::stringValue(absl::string_view value) { @@ -213,8 +178,8 @@ FilterStatus ShadowRouterImpl::stringValue(absl::string_view value) { return ProtocolConverter::stringValue(value); } - auto cb = [value_str = std::string(value), this]() { - ProtocolConverter::stringValue(absl::string_view(value_str)); + auto cb = [value_str = std::string(value), this]() -> FilterStatus { + return ProtocolConverter::stringValue(absl::string_view(value_str)); }; pending_callbacks_.push_back(std::move(cb)); @@ -223,75 +188,37 @@ FilterStatus ShadowRouterImpl::stringValue(absl::string_view value) { FilterStatus ShadowRouterImpl::mapBegin(FieldType& key_type, FieldType& value_type, uint32_t& size) { - if (requestStarted()) { + return runOrSave([key_type, value_type, size, this]() mutable -> FilterStatus { return ProtocolConverter::mapBegin(key_type, value_type, size); - } - - auto cb = [key_type, value_type, size, this]() mutable { - ProtocolConverter::mapBegin(key_type, value_type, size); - }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + }); } FilterStatus ShadowRouterImpl::mapEnd() { - if (requestStarted()) { - return ProtocolConverter::mapEnd(); - } - - auto cb = [this]() { ProtocolConverter::mapEnd(); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave([this]() -> FilterStatus { return ProtocolConverter::mapEnd(); }); } FilterStatus ShadowRouterImpl::listBegin(FieldType& elem_type, uint32_t& size) { - if (requestStarted()) { + return runOrSave([elem_type, size, this]() mutable -> FilterStatus { return ProtocolConverter::listBegin(elem_type, size); - } - - auto cb = [elem_type, size, this]() mutable { ProtocolConverter::listBegin(elem_type, size); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + }); } FilterStatus ShadowRouterImpl::listEnd() { - if (requestStarted()) { - return ProtocolConverter::listEnd(); - } - - auto cb = [this]() { ProtocolConverter::listEnd(); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave([this]() -> FilterStatus { return ProtocolConverter::listEnd(); }); } FilterStatus ShadowRouterImpl::setBegin(FieldType& elem_type, uint32_t& size) { - if (requestStarted()) { + return runOrSave([elem_type, size, this]() mutable -> FilterStatus { return ProtocolConverter::setBegin(elem_type, size); - } - - auto cb = [elem_type, size, this]() mutable { ProtocolConverter::setBegin(elem_type, size); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + }); } FilterStatus ShadowRouterImpl::setEnd() { - if (requestStarted()) { - return ProtocolConverter::setEnd(); - } - - auto cb = [this]() { ProtocolConverter::setEnd(); }; - pending_callbacks_.push_back(std::move(cb)); - - return FilterStatus::Continue; + return runOrSave([this]() -> FilterStatus { return ProtocolConverter::setEnd(); }); } FilterStatus ShadowRouterImpl::messageEnd() { - auto cb = [this]() { + auto cb = [this]() -> FilterStatus { ASSERT(upstream_request_->conn_data_ != nullptr); ProtocolConverter::messageEnd(); @@ -304,16 +231,11 @@ FilterStatus ShadowRouterImpl::messageEnd() { if (metadata_->messageType() == MessageType::Oneway) { upstream_request_->releaseConnection(false); } - }; - if (requestStarted()) { - cb(); - } else { - request_ready_ = true; - pending_callbacks_.push_back(std::move(cb)); - } + return FilterStatus::Continue; + }; - return FilterStatus::Continue; + return runOrSave(std::move(cb), [this]() -> void { request_ready_ = true; }); } bool ShadowRouterImpl::requestInProgress() { diff --git a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h index 7c14409db18c3..dc43592b58cf9 100644 --- a/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h +++ b/source/extensions/filters/network/thrift_proxy/router/shadow_writer_impl.h @@ -199,11 +199,14 @@ class ShadowRouterImpl : public ShadowRouterHandle, private: friend class ShadowWriterTest; + using ConverterCallback = std::function; void writeRequest(); bool requestInProgress(); bool requestStarted() const; void flushPendingCallbacks(); + FilterStatus runOrSave(std::function&& cb, + const std::function& on_save = {}); ShadowWriterImpl& parent_; const std::string cluster_name_; @@ -222,7 +225,6 @@ class ShadowRouterImpl : public ShadowRouterHandle, uint64_t response_size_{}; bool request_ready_ : 1; - using ConverterCallback = std::function; std::list pending_callbacks_; bool removed_{}; bool deferred_deleting_{};