diff --git a/changelogs/1.23.0.yaml b/changelogs/1.23.0.yaml index c63c0b33b13c2..75da45284b68e 100644 --- a/changelogs/1.23.0.yaml +++ b/changelogs/1.23.0.yaml @@ -66,6 +66,9 @@ new_features: - area: thrift change: | added flag to router to control downstream local close. :ref:`close_downstream_on_upstream_error `. +- area: thrift + change: | + introduced thrift configurable encoder and bidirectional filters, which allows peeking and modifying the thrift response message. - area: on_demand change: | :ref:`OnDemand ` got extended to hold configuration for on-demand cluster discovery. A similar message for :ref:`per-route configuration ` is also added. diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index d3c1e0af8e17f..671a798277187 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -70,6 +70,7 @@ envoy_cc_library( deps = [ ":app_exception_lib", ":decoder_lib", + ":filter_utils_lib", ":protocol_converter_lib", ":protocol_interface", ":stats_lib", @@ -115,6 +116,15 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "filter_utils_lib", + srcs = ["filter_utils.cc"], + hdrs = ["filter_utils.h"], + deps = [ + "//source/extensions/filters/network/thrift_proxy/filters:filter_interface", + ], +) + envoy_cc_library( name = "metadata_lib", hdrs = ["metadata.h"], diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.cc b/source/extensions/filters/network/thrift_proxy/conn_manager.cc index a64cf63e91d85..d06b305a7f768 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.cc +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.cc @@ -191,6 +191,10 @@ DecoderEventHandler& ConnectionManager::newDecoderEventHandler() { return **rpcs_.begin(); } +bool ConnectionManager::ResponseDecoder::passthroughEnabled() const { + return parent_.parent_.passthroughEnabled(); +} + bool ConnectionManager::passthroughEnabled() const { if (!config_.payloadPassthrough()) { return false; @@ -217,50 +221,29 @@ bool ConnectionManager::ResponseDecoder::onData(Buffer::Instance& data) { return complete_; } -FilterStatus ConnectionManager::ResponseDecoder::passthroughData(Buffer::Instance& data) { - passthrough_ = true; - return ProtocolConverter::passthroughData(data); +FilterStatus ConnectionManager::ResponseDecoder::transportBegin(MessageMetadataSharedPtr metadata) { + return parent_.applyEncoderFilters(DecoderEvent::TransportBegin, metadata, protocol_converter_); } -FilterStatus ConnectionManager::ResponseDecoder::messageBegin(MessageMetadataSharedPtr metadata) { - metadata_ = metadata; - metadata_->setSequenceId(parent_.original_sequence_id_); +FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { + ASSERT(metadata_ != nullptr); - if (metadata->hasReplyType()) { - success_ = metadata->replyType() == ReplyType::Success; + FilterStatus status = + parent_.applyEncoderFilters(DecoderEvent::TransportEnd, absl::any(), protocol_converter_); + // Currently we don't support returning FilterStatus::StopIteration from encoder filters. + // Hence, this if-statement is always false. + ASSERT(status == FilterStatus::Continue); + if (status == FilterStatus::StopIteration) { + pending_transport_end_ = true; + return FilterStatus::StopIteration; } - // Check if the upstream host is draining. - // - // Note: the drain header needs to be checked here in messageBegin, and not transportBegin, so - // that we can support the header in TTwitter protocol, which reads/adds response headers to - // metadata in messageBegin when reading the response from upstream. Therefore detecting a drain - // should happen here. - if (Runtime::runtimeFeatureEnabled("envoy.reloadable_features.thrift_connection_draining")) { - metadata_->setDraining(!metadata->headers().get(Headers::get().Drain).empty()); - metadata->headers().remove(Headers::get().Drain); - - // Check if this host itself is draining. - // - // Note: Similarly as above, the response is buffered until transportEnd. Therefore metadata - // should be set before the encodeFrame() call. It should be set at or after the messageBegin - // call so that the header is added after all upstream headers passed, due to messageBegin - // possibly not getting headers in transportBegin. - ConnectionManager& cm = parent_.parent_; - if (cm.drain_decision_.drainClose()) { - // TODO(rgs1): should the key value contain something useful (e.g.: minutes til drain is - // over)? - metadata->headers().addReferenceKey(Headers::get().Drain, "true"); - parent_.parent_.stats_.downstream_response_drain_close_.inc(); - } - } - - return ProtocolConverter::messageBegin(metadata); + finalizeResponse(); + return FilterStatus::Continue; } -FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { - ASSERT(metadata_ != nullptr); - +void ConnectionManager::ResponseDecoder::finalizeResponse() { + pending_transport_end_ = false; ConnectionManager& cm = parent_.parent_; if (cm.read_callbacks_->connection().state() == Network::Connection::State::Closed) { @@ -308,16 +291,135 @@ FilterStatus ConnectionManager::ResponseDecoder::transportEnd() { cm.stats_.response_invalid_type_.inc(); break; } +} - return FilterStatus::Continue; +FilterStatus ConnectionManager::ResponseDecoder::passthroughData(Buffer::Instance& data) { + passthrough_ = true; + + return parent_.applyEncoderFilters(DecoderEvent::PassthroughData, &data, protocol_converter_); } -bool ConnectionManager::ResponseDecoder::passthroughEnabled() const { - return parent_.parent_.passthroughEnabled(); +FilterStatus ConnectionManager::ResponseDecoder::messageBegin(MessageMetadataSharedPtr metadata) { + metadata_ = metadata; + metadata_->setSequenceId(parent_.original_sequence_id_); + + if (metadata->hasReplyType()) { + // TODO(kuochunghsu): the status of success could be altered by filters + success_ = metadata->replyType() == ReplyType::Success; + } + + // Check if the upstream host is draining. + // + // Note: the drain header needs to be checked here in messageBegin, and not transportBegin, so + // that we can support the header in TTwitter protocol, which reads/adds response headers to + // metadata in messageBegin when reading the response from upstream. Therefore detecting a drain + // should happen here. + if (Runtime::runtimeFeatureEnabled("envoy.reloadable_features.thrift_connection_draining")) { + metadata_->setDraining(!metadata->headers().get(Headers::get().Drain).empty()); + metadata->headers().remove(Headers::get().Drain); + + // Check if this host itself is draining. + // + // Note: Similarly as above, the response is buffered until transportEnd. Therefore metadata + // should be set before the encodeFrame() call. It should be set at or after the messageBegin + // call so that the header is added after all upstream headers passed, due to messageBegin + // possibly not getting headers in transportBegin. + ConnectionManager& cm = parent_.parent_; + if (cm.drain_decision_.drainClose()) { + // TODO(rgs1): should the key value contain something useful (e.g.: minutes til drain is + // over)? + metadata->headers().addReferenceKey(Headers::get().Drain, "true"); + parent_.parent_.stats_.downstream_response_drain_close_.inc(); + } + } + + return parent_.applyEncoderFilters(DecoderEvent::MessageBegin, metadata, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::messageEnd() { + return parent_.applyEncoderFilters(DecoderEvent::MessageEnd, absl::any(), protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::structBegin(absl::string_view name) { + return parent_.applyEncoderFilters(DecoderEvent::StructBegin, name, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::structEnd() { + return parent_.applyEncoderFilters(DecoderEvent::StructEnd, absl::any(), protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::fieldBegin(absl::string_view name, + FieldType& field_type, + int16_t& field_id) { + return parent_.applyEncoderFilters(DecoderEvent::FieldBegin, + std::make_tuple(std::string(name), field_type, field_id), + protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::fieldEnd() { + return parent_.applyEncoderFilters(DecoderEvent::FieldEnd, absl::any(), protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::boolValue(bool& value) { + return parent_.applyEncoderFilters(DecoderEvent::BoolValue, value, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::byteValue(uint8_t& value) { + return parent_.applyEncoderFilters(DecoderEvent::ByteValue, value, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::int16Value(int16_t& value) { + return parent_.applyEncoderFilters(DecoderEvent::Int16Value, value, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::int32Value(int32_t& value) { + return parent_.applyEncoderFilters(DecoderEvent::Int32Value, value, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::int64Value(int64_t& value) { + return parent_.applyEncoderFilters(DecoderEvent::Int64Value, value, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::doubleValue(double& value) { + return parent_.applyEncoderFilters(DecoderEvent::DoubleValue, value, protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::stringValue(absl::string_view value) { + return parent_.applyEncoderFilters(DecoderEvent::StringValue, std::string(value), + protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::mapBegin(FieldType& key_type, + FieldType& value_type, uint32_t& size) { + return parent_.applyEncoderFilters( + DecoderEvent::MapBegin, std::make_tuple(key_type, value_type, size), protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::mapEnd() { + return parent_.applyEncoderFilters(DecoderEvent::MapEnd, absl::any(), protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::listBegin(FieldType& elem_type, uint32_t& size) { + return parent_.applyEncoderFilters(DecoderEvent::ListBegin, std::make_tuple(elem_type, size), + protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::listEnd() { + return parent_.applyEncoderFilters(DecoderEvent::ListEnd, absl::any(), protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::setBegin(FieldType& elem_type, uint32_t& size) { + return parent_.applyEncoderFilters(DecoderEvent::SetBegin, std::make_tuple(elem_type, size), + protocol_converter_); +} + +FilterStatus ConnectionManager::ResponseDecoder::setEnd() { + return parent_.applyEncoderFilters(DecoderEvent::SetEnd, absl::any(), protocol_converter_); } void ConnectionManager::ActiveRpcDecoderFilter::continueDecoding() { - const FilterStatus status = parent_.applyDecoderFilters(this); + const FilterStatus status = + parent_.applyDecoderFilters(DecoderEvent::ContinueDecode, absl::any(), this); if (status == FilterStatus::Continue) { // All filters have been executed for the current decoder state. if (parent_.pending_transport_end_) { @@ -329,8 +431,15 @@ void ConnectionManager::ActiveRpcDecoderFilter::continueDecoding() { } } -FilterStatus ConnectionManager::ActiveRpc::applyDecoderFilters(ActiveRpcDecoderFilter* filter) { - ASSERT(filter_action_ != nullptr); +void ConnectionManager::ActiveRpcEncoderFilter::continueEncoding() { + // Not supported. + ASSERT(false); +} + +FilterStatus ConnectionManager::ActiveRpc::applyDecoderFilters(DecoderEvent state, absl::any data, + ActiveRpcDecoderFilter* filter) { + ASSERT(filter_action_ == nullptr || state == DecoderEvent::ContinueDecode); + prepareFilterAction(state, data); if (local_response_sent_) { filter_action_ = nullptr; @@ -341,14 +450,39 @@ FilterStatus ConnectionManager::ActiveRpc::applyDecoderFilters(ActiveRpcDecoderF if (upgrade_handler_) { // Divert events to the current protocol upgrade handler. const FilterStatus status = filter_action_(upgrade_handler_.get()); + filter_action_ = nullptr; filter_context_.reset(); return status; } - std::list::iterator entry = - !filter ? decoder_filters_.begin() : std::next(filter->entry()); - for (; entry != decoder_filters_.end(); entry++) { - const FilterStatus status = filter_action_((*entry)->handle_.get()); + return applyFilters(filter, decoder_filters_); +} + +FilterStatus +ConnectionManager::ActiveRpc::applyEncoderFilters(DecoderEvent state, absl::any data, + ProtocolConverterSharedPtr protocol_converter, + ActiveRpcEncoderFilter* filter) { + ASSERT(filter_action_ == nullptr || state == DecoderEvent::ContinueDecode); + prepareFilterAction(state, data); + + FilterStatus status = + applyFilters(filter, encoder_filters_, protocol_converter); + // FilterStatus::StopIteration is currently not supported. + ASSERT(status == FilterStatus::Continue); + + return status; +} + +template +FilterStatus +ConnectionManager::ActiveRpc::applyFilters(FilterType* filter, + std::list>& filter_list, + ProtocolConverterSharedPtr protocol_converter) { + + typename std::list>::iterator entry = + !filter ? filter_list.begin() : std::next(filter->entry()); + for (; entry != filter_list.end(); entry++) { + const FilterStatus status = filter_action_((*entry)->decodeEventHandler()); if (local_response_sent_) { // The filter called sendLocalReply but _did not_ close the connection. // We return FilterStatus::Continue irrespective of the current result, @@ -374,20 +508,165 @@ FilterStatus ConnectionManager::ActiveRpc::applyDecoderFilters(ActiveRpcDecoderF } } + // The protocol converter writes the data to a buffer for response. + if (protocol_converter) { + filter_action_(protocol_converter.get()); + } + filter_action_ = nullptr; filter_context_.reset(); return FilterStatus::Continue; } -FilterStatus ConnectionManager::ActiveRpc::transportBegin(MessageMetadataSharedPtr metadata) { - filter_context_ = metadata; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - MessageMetadataSharedPtr metadata = absl::any_cast(filter_context_); - return filter->transportBegin(metadata); - }; +void ConnectionManager::ActiveRpc::prepareFilterAction(DecoderEvent event, absl::any data) { + // DecoderEvent::ContinueDecode indicates we're handling previous filter action with the + // filter chain. Therefore, we should not reset filter_action_ and filter_context_. + if (event == DecoderEvent::ContinueDecode) { + return; + } + filter_context_ = data; + + switch (event) { + case DecoderEvent::TransportBegin: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + MessageMetadataSharedPtr metadata = absl::any_cast(filter_context_); + return filter->transportBegin(metadata); + }; + break; + case DecoderEvent::TransportEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { + return filter->transportEnd(); + }; + break; + case DecoderEvent::PassthroughData: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + Buffer::Instance* data = absl::any_cast(filter_context_); + return filter->passthroughData(*data); + }; + break; + case DecoderEvent::MessageBegin: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + MessageMetadataSharedPtr metadata = absl::any_cast(filter_context_); + return filter->messageBegin(metadata); + }; + break; + case DecoderEvent::MessageEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { + return filter->messageEnd(); + }; + break; + case DecoderEvent::StructBegin: - return applyDecoderFilters(nullptr); + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + absl::string_view name = absl::any_cast(filter_context_); + return filter->structBegin(name); + }; + break; + case DecoderEvent::StructEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { + return filter->structEnd(); + }; + break; + case DecoderEvent::FieldBegin: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + std::tuple& t = + absl::any_cast&>(filter_context_); + std::string& name = std::get<0>(t); + FieldType& field_type = std::get<1>(t); + int16_t& field_id = std::get<2>(t); + return filter->fieldBegin(name, field_type, field_id); + }; + break; + case DecoderEvent::FieldEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->fieldEnd(); }; + break; + case DecoderEvent::BoolValue: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + bool& value = absl::any_cast(filter_context_); + return filter->boolValue(value); + }; + break; + case DecoderEvent::ByteValue: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + uint8_t& value = absl::any_cast(filter_context_); + return filter->byteValue(value); + }; + break; + case DecoderEvent::Int16Value: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + int16_t& value = absl::any_cast(filter_context_); + return filter->int16Value(value); + }; + break; + case DecoderEvent::Int32Value: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + int32_t& value = absl::any_cast(filter_context_); + return filter->int32Value(value); + }; + break; + case DecoderEvent::Int64Value: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + int64_t& value = absl::any_cast(filter_context_); + return filter->int64Value(value); + }; + break; + case DecoderEvent::DoubleValue: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + double& value = absl::any_cast(filter_context_); + return filter->doubleValue(value); + }; + break; + case DecoderEvent::StringValue: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + std::string& value = absl::any_cast(filter_context_); + return filter->stringValue(value); + }; + break; + case DecoderEvent::MapBegin: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + std::tuple& t = + absl::any_cast&>(filter_context_); + FieldType& key_type = std::get<0>(t); + FieldType& value_type = std::get<1>(t); + uint32_t& size = std::get<2>(t); + return filter->mapBegin(key_type, value_type, size); + }; + break; + case DecoderEvent::MapEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->mapEnd(); }; + break; + case DecoderEvent::ListBegin: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + std::tuple& t = + absl::any_cast&>(filter_context_); + FieldType& elem_type = std::get<0>(t); + uint32_t& size = std::get<1>(t); + return filter->listBegin(elem_type, size); + }; + break; + case DecoderEvent::ListEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->listEnd(); }; + break; + case DecoderEvent::SetBegin: + filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { + std::tuple& t = + absl::any_cast&>(filter_context_); + FieldType& elem_type = std::get<0>(t); + uint32_t& size = std::get<1>(t); + return filter->setBegin(elem_type, size); + }; + break; + case DecoderEvent::SetEnd: + filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->setEnd(); }; + break; + default: + PANIC_DUE_TO_CORRUPT_ENUM; + } +} + +FilterStatus ConnectionManager::ActiveRpc::transportBegin(MessageMetadataSharedPtr metadata) { + return applyDecoderFilters(DecoderEvent::TransportBegin, metadata); } FilterStatus ConnectionManager::ActiveRpc::transportEnd() { @@ -403,11 +682,7 @@ FilterStatus ConnectionManager::ActiveRpc::transportEnd() { sendLocalReply(*parent_.protocol_->upgradeResponse(*upgrade_handler_), false); } } else { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { - return filter->transportEnd(); - }; - - status = applyDecoderFilters(nullptr); + status = applyDecoderFilters(DecoderEvent::TransportEnd, absl::any()); if (status == FilterStatus::StopIteration) { pending_transport_end_ = true; return status; @@ -465,9 +740,17 @@ void ConnectionManager::ActiveRpc::finalizeRequest() { } } +// TODO(kuochunghsu): passthroughSupported for decoder/encoder filters with more flexibility. +// That is, supporting passthrough data for decoder filters if all decoder filters agree, +// and supporting passthrough data for encoder filters if all encoder filters agree. bool ConnectionManager::ActiveRpc::passthroughSupported() const { for (auto& entry : decoder_filters_) { - if (!entry->handle_->passthroughSupported()) { + if (!entry->decoder_handle_->passthroughSupported()) { + return false; + } + } + for (auto& entry : encoder_filters_) { + if (!entry->encoder_handle_->passthroughSupported()) { return false; } } @@ -476,13 +759,7 @@ bool ConnectionManager::ActiveRpc::passthroughSupported() const { FilterStatus ConnectionManager::ActiveRpc::passthroughData(Buffer::Instance& data) { passthrough_ = true; - filter_context_ = &data; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - Buffer::Instance* data = absl::any_cast(filter_context_); - return filter->passthroughData(*data); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::PassthroughData, &data); } FilterStatus ConnectionManager::ActiveRpc::messageBegin(MessageMetadataSharedPtr metadata) { @@ -502,184 +779,82 @@ FilterStatus ConnectionManager::ActiveRpc::messageBegin(MessageMetadataSharedPtr ASSERT(upgrade_handler_ != nullptr); } - filter_context_ = metadata; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - MessageMetadataSharedPtr metadata = absl::any_cast(filter_context_); - return filter->messageBegin(metadata); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::MessageBegin, metadata); } FilterStatus ConnectionManager::ActiveRpc::messageEnd() { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->messageEnd(); }; - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::MessageEnd, absl::any()); } FilterStatus ConnectionManager::ActiveRpc::structBegin(absl::string_view name) { - filter_context_ = std::string(name); - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - std::string& name = absl::any_cast(filter_context_); - return filter->structBegin(name); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::StructBegin, name); } FilterStatus ConnectionManager::ActiveRpc::structEnd() { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->structEnd(); }; - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::StructEnd, absl::any()); } FilterStatus ConnectionManager::ActiveRpc::fieldBegin(absl::string_view name, FieldType& field_type, int16_t& field_id) { - filter_context_ = - std::tuple(std::string(name), field_type, field_id); - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - std::tuple& t = - absl::any_cast&>(filter_context_); - std::string& name = std::get<0>(t); - FieldType& field_type = std::get<1>(t); - int16_t& field_id = std::get<2>(t); - return filter->fieldBegin(name, field_type, field_id); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::FieldBegin, + std::make_tuple(std::string(name), field_type, field_id)); } FilterStatus ConnectionManager::ActiveRpc::fieldEnd() { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->fieldEnd(); }; - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::FieldEnd, absl::any()); } FilterStatus ConnectionManager::ActiveRpc::boolValue(bool& value) { - filter_context_ = value; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - bool& value = absl::any_cast(filter_context_); - return filter->boolValue(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::BoolValue, value); } FilterStatus ConnectionManager::ActiveRpc::byteValue(uint8_t& value) { - filter_context_ = value; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - uint8_t& value = absl::any_cast(filter_context_); - return filter->byteValue(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::ByteValue, value); } FilterStatus ConnectionManager::ActiveRpc::int16Value(int16_t& value) { - filter_context_ = value; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - int16_t& value = absl::any_cast(filter_context_); - return filter->int16Value(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::Int16Value, value); } FilterStatus ConnectionManager::ActiveRpc::int32Value(int32_t& value) { - filter_context_ = value; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - int32_t& value = absl::any_cast(filter_context_); - return filter->int32Value(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::Int32Value, value); } FilterStatus ConnectionManager::ActiveRpc::int64Value(int64_t& value) { - filter_context_ = value; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - int64_t& value = absl::any_cast(filter_context_); - return filter->int64Value(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::Int64Value, value); } FilterStatus ConnectionManager::ActiveRpc::doubleValue(double& value) { - filter_context_ = value; - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - double& value = absl::any_cast(filter_context_); - return filter->doubleValue(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::DoubleValue, value); } FilterStatus ConnectionManager::ActiveRpc::stringValue(absl::string_view value) { - filter_context_ = std::string(value); - - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - std::string& value = absl::any_cast(filter_context_); - return filter->stringValue(value); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::StringValue, std::string(value)); } FilterStatus ConnectionManager::ActiveRpc::mapBegin(FieldType& key_type, FieldType& value_type, uint32_t& size) { - filter_context_ = std::tuple(key_type, value_type, size); - - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - std::tuple& t = - absl::any_cast&>(filter_context_); - FieldType& key_type = std::get<0>(t); - FieldType& value_type = std::get<1>(t); - uint32_t& size = std::get<2>(t); - return filter->mapBegin(key_type, value_type, size); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::MapBegin, std::make_tuple(key_type, value_type, size)); } FilterStatus ConnectionManager::ActiveRpc::mapEnd() { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->mapEnd(); }; - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::MapEnd, absl::any()); } FilterStatus ConnectionManager::ActiveRpc::listBegin(FieldType& value_type, uint32_t& size) { - filter_context_ = std::tuple(value_type, size); - - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - std::tuple& t = - absl::any_cast&>(filter_context_); - FieldType& value_type = std::get<0>(t); - uint32_t& size = std::get<1>(t); - return filter->listBegin(value_type, size); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::ListBegin, std::make_tuple(value_type, size)); } FilterStatus ConnectionManager::ActiveRpc::listEnd() { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->listEnd(); }; - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::ListEnd, absl::any()); } FilterStatus ConnectionManager::ActiveRpc::setBegin(FieldType& value_type, uint32_t& size) { - filter_context_ = std::tuple(value_type, size); - - filter_action_ = [this](DecoderEventHandler* filter) -> FilterStatus { - std::tuple& t = - absl::any_cast&>(filter_context_); - FieldType& value_type = std::get<0>(t); - uint32_t& size = std::get<1>(t); - return filter->setBegin(value_type, size); - }; - - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::SetBegin, std::make_tuple(value_type, size)); } FilterStatus ConnectionManager::ActiveRpc::setEnd() { - filter_action_ = [](DecoderEventHandler* filter) -> FilterStatus { return filter->setEnd(); }; - return applyDecoderFilters(nullptr); + return applyDecoderFilters(DecoderEvent::SetEnd, absl::any()); } void ConnectionManager::ActiveRpc::createFilterChain() { diff --git a/source/extensions/filters/network/thrift_proxy/conn_manager.h b/source/extensions/filters/network/thrift_proxy/conn_manager.h index 5b4a541d265f0..d64a74244ed53 100644 --- a/source/extensions/filters/network/thrift_proxy/conn_manager.h +++ b/source/extensions/filters/network/thrift_proxy/conn_manager.h @@ -14,6 +14,7 @@ #include "source/common/stats/timespan_impl.h" #include "source/common/stream_info/stream_info_impl.h" #include "source/extensions/filters/network/thrift_proxy/decoder.h" +#include "source/extensions/filters/network/thrift_proxy/filter_utils.h" #include "source/extensions/filters/network/thrift_proxy/filters/filter.h" #include "source/extensions/filters/network/thrift_proxy/protocol.h" #include "source/extensions/filters/network/thrift_proxy/protocol_converter.h" @@ -72,50 +73,68 @@ class ConnectionManager : public Network::ReadFilter, private: struct ActiveRpc; - struct ResponseDecoder : public DecoderCallbacks, public ProtocolConverter { + struct ResponseDecoder : public DecoderCallbacks, public DecoderEventHandler { ResponseDecoder(ActiveRpc& parent, Transport& transport, Protocol& protocol) : parent_(parent), decoder_(std::make_unique(transport, protocol, *this)), - complete_(false), passthrough_{false} { - initProtocolConverter(*parent_.parent_.protocol_, parent_.response_buffer_); + protocol_converter_(std::make_shared()), complete_{false}, + passthrough_{false}, pending_transport_end_{false} { + ; + protocol_converter_->initProtocolConverter(*parent_.parent_.protocol_, + parent_.response_buffer_); } bool onData(Buffer::Instance& data); - // ProtocolConverter + // DecoderEventHandler + FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override; + FilterStatus transportEnd() override; FilterStatus passthroughData(Buffer::Instance& data) override; FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override; - FilterStatus transportBegin(MessageMetadataSharedPtr metadata) override { - UNREFERENCED_PARAMETER(metadata); - return FilterStatus::Continue; - } - FilterStatus transportEnd() override; + FilterStatus messageEnd() override; + FilterStatus structBegin(absl::string_view name) override; + FilterStatus structEnd() override; + FilterStatus fieldBegin(absl::string_view name, FieldType& field_type, + int16_t& field_id) override; + FilterStatus fieldEnd() override; + FilterStatus boolValue(bool& value) override; + FilterStatus byteValue(uint8_t& value) override; + FilterStatus int16Value(int16_t& value) override; + FilterStatus int32Value(int32_t& value) override; + FilterStatus int64Value(int64_t& value) override; + FilterStatus doubleValue(double& value) override; + FilterStatus stringValue(absl::string_view value) override; + FilterStatus mapBegin(FieldType& key_type, FieldType& value_type, uint32_t& size) override; + FilterStatus mapEnd() override; + FilterStatus listBegin(FieldType& elem_type, uint32_t& size) override; + FilterStatus listEnd() override; + FilterStatus setBegin(FieldType& elem_type, uint32_t& size) override; + FilterStatus setEnd() override; // DecoderCallbacks DecoderEventHandler& newDecoderEventHandler() override { return *this; } bool passthroughEnabled() const override; + void finalizeResponse(); + ActiveRpc& parent_; DecoderPtr decoder_; Buffer::OwnedImpl upstream_buffer_; MessageMetadataSharedPtr metadata_; + ProtocolConverterSharedPtr protocol_converter_; absl::optional success_; bool complete_ : 1; bool passthrough_ : 1; + bool pending_transport_end_ : 1; }; using ResponseDecoderPtr = std::unique_ptr; - // Wraps a DecoderFilter and acts as the DecoderFilterCallbacks for the filter, enabling filter - // chain continuation. - struct ActiveRpcDecoderFilter : public ThriftFilters::DecoderFilterCallbacks, - LinkedObject { - ActiveRpcDecoderFilter(ActiveRpc& parent, ThriftFilters::DecoderFilterSharedPtr filter) - : parent_(parent), handle_(filter) {} + struct ActiveRpcFilterBase : public virtual ThriftFilters::FilterCallbacks { + ActiveRpcFilterBase(ActiveRpc& parent) : parent_(parent) {} - // ThriftFilters::DecoderFilterCallbacks + // ThriftFilters::FilterCallbacks uint64_t streamId() const override { return parent_.stream_id_; } const Network::Connection* connection() const override { return parent_.connection(); } Event::Dispatcher& dispatcher() override { return parent_.dispatcher(); } - void continueDecoding() override; Router::RouteConstSharedPtr route() override { return parent_.route(); } TransportType downstreamTransportType() const override { return parent_.downstreamTransportType(); @@ -123,6 +142,25 @@ class ConnectionManager : public Network::ReadFilter, ProtocolType downstreamProtocolType() const override { return parent_.downstreamProtocolType(); } + + void resetDownstreamConnection() override { parent_.resetDownstreamConnection(); } + StreamInfo::StreamInfo& streamInfo() override { return parent_.streamInfo(); } + MessageMetadataSharedPtr responseMetadata() override { return parent_.responseMetadata(); } + bool responseSuccess() override { return parent_.responseSuccess(); } + void onReset() override { parent_.onReset(); } + + ActiveRpc& parent_; + }; + + // Wraps a DecoderFilter and acts as the DecoderFilterCallbacks for the filter, enabling filter + // chain continuation. + struct ActiveRpcDecoderFilter : public ActiveRpcFilterBase, + public virtual ThriftFilters::DecoderFilterCallbacks, + LinkedObject { + ActiveRpcDecoderFilter(ActiveRpc& parent, ThriftFilters::DecoderFilterSharedPtr filter) + : ActiveRpcFilterBase(parent), decoder_handle_(filter) {} + + // ThriftFilters::DecoderFilterCallbacks void sendLocalReply(const DirectResponse& response, bool end_stream) override { parent_.sendLocalReply(response, end_stream); } @@ -132,22 +170,33 @@ class ConnectionManager : public Network::ReadFilter, ThriftFilters::ResponseStatus upstreamData(Buffer::Instance& buffer) override { return parent_.upstreamData(buffer); } - void resetDownstreamConnection() override { parent_.resetDownstreamConnection(); } - StreamInfo::StreamInfo& streamInfo() override { return parent_.streamInfo(); } - MessageMetadataSharedPtr responseMetadata() override { return parent_.responseMetadata(); } - bool responseSuccess() override { return parent_.responseSuccess(); } - void onReset() override { parent_.onReset(); } - - ActiveRpc& parent_; - ThriftFilters::DecoderFilterSharedPtr handle_; + void continueDecoding() override; + DecoderEventHandler* decodeEventHandler() { return decoder_handle_.get(); } + ThriftFilters::DecoderFilterSharedPtr decoder_handle_; }; using ActiveRpcDecoderFilterPtr = std::unique_ptr; + // Wraps a EncoderFilter and acts as the EncoderFilterCallbacks for the filter, enabling filter + // chain continuation. + struct ActiveRpcEncoderFilter : public ActiveRpcFilterBase, + public virtual ThriftFilters::EncoderFilterCallbacks, + LinkedObject { + ActiveRpcEncoderFilter(ActiveRpc& parent, ThriftFilters::EncoderFilterSharedPtr filter) + : ActiveRpcFilterBase(parent), encoder_handle_(filter) {} + + // ThriftFilters::EncoderFilterCallbacks + void continueEncoding() override; + DecoderEventHandler* decodeEventHandler() { return encoder_handle_.get(); } + ThriftFilters::EncoderFilterSharedPtr encoder_handle_; + }; + using ActiveRpcEncoderFilterPtr = std::unique_ptr; + // ActiveRpc tracks request/response pairs. struct ActiveRpc : LinkedObject, public Event::DeferredDeletable, public DecoderEventHandler, public ThriftFilters::DecoderFilterCallbacks, + public ThriftFilters::EncoderFilterCallbacks, public ThriftFilters::FilterChainFactoryCallbacks { ActiveRpc(ConnectionManager& parent) : parent_(parent), request_timer_(new Stats::HistogramCompletableTimespanImpl( @@ -162,8 +211,8 @@ class ConnectionManager : public Network::ReadFilter, request_timer_->complete(); parent_.stats_.request_active_.dec(); - for (auto& filter : decoder_filters_) { - filter->handle_->onDestroy(); + for (auto& filter : base_filters_) { + filter->onDestroy(); } } @@ -199,6 +248,7 @@ class ConnectionManager : public Network::ReadFilter, return parent_.read_callbacks_->connection().dispatcher(); } void continueDecoding() override { parent_.continueDecoding(); } + void continueEncoding() override {} Router::RouteConstSharedPtr route() override; TransportType downstreamTransportType() const override { return parent_.decoder_->transportType(); @@ -220,10 +270,52 @@ class ConnectionManager : public Network::ReadFilter, ActiveRpcDecoderFilterPtr wrapper = std::make_unique(*this, filter); filter->setDecoderFilterCallbacks(*wrapper); LinkedList::moveIntoListBack(std::move(wrapper), decoder_filters_); + base_filters_.emplace_back(filter); + } + + void addEncoderFilter(ThriftFilters::EncoderFilterSharedPtr filter) override { + ActiveRpcEncoderFilterPtr wrapper = std::make_unique(*this, filter); + filter->setEncoderFilterCallbacks(*wrapper); + LinkedList::moveIntoList(std::move(wrapper), encoder_filters_); + base_filters_.emplace_back(filter); + } + + void addBidirectionalFilter(ThriftFilters::BidirectionalFilterSharedPtr filter) override { + ThriftFilters::BidirectionalFilterWrapperSharedPtr wrapper = + std::make_unique(filter); + + ActiveRpcDecoderFilterPtr decoder_wrapper = + std::make_unique(*this, wrapper->decoder_filter_); + filter->setDecoderFilterCallbacks(*decoder_wrapper); + LinkedList::moveIntoListBack(std::move(decoder_wrapper), decoder_filters_); + + ActiveRpcEncoderFilterPtr encoder_wrapper = + std::make_unique(*this, wrapper->encoder_filter_); + filter->setEncoderFilterCallbacks(*encoder_wrapper); + LinkedList::moveIntoList(std::move(encoder_wrapper), encoder_filters_); + + base_filters_.emplace_back(wrapper); } bool passthroughSupported() const; - FilterStatus applyDecoderFilters(ActiveRpcDecoderFilter* filter); + + // Apply filters to the decoder_event. + // @param filter the last filter which is already applied to the decoder_event. + // nullptr indicates none is applied and the decoder_event is applied from the + // first filter. + FilterStatus applyDecoderFilters(DecoderEvent state, absl::any data, + ActiveRpcDecoderFilter* filter = nullptr); + FilterStatus applyEncoderFilters(DecoderEvent state, absl::any data, + ProtocolConverterSharedPtr protocol_converter, + ActiveRpcEncoderFilter* filter = nullptr); + template + FilterStatus applyFilters(FilterType* filter, + std::list>& filter_list, + ProtocolConverterSharedPtr protocol_converter = nullptr); + + // Helper to setup filter_action_ and filter_context_ + void prepareFilterAction(DecoderEvent event, absl::any data); + void finalizeRequest(); void createFilterChain(); @@ -235,6 +327,8 @@ class ConnectionManager : public Network::ReadFilter, StreamInfo::StreamInfoImpl stream_info_; MessageMetadataSharedPtr metadata_; std::list decoder_filters_; + std::list encoder_filters_; + std::list base_filters_; DecoderEventHandlerSharedPtr upgrade_handler_; ResponseDecoderPtr response_decoder_; absl::optional cached_route_; diff --git a/source/extensions/filters/network/thrift_proxy/decoder_events.h b/source/extensions/filters/network/thrift_proxy/decoder_events.h index 342f65ece6856..42e4f77580b24 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder_events.h +++ b/source/extensions/filters/network/thrift_proxy/decoder_events.h @@ -17,6 +17,32 @@ enum class FilterStatus { StopIteration }; +enum class DecoderEvent { + TransportBegin, + TransportEnd, + PassthroughData, + MessageBegin, + MessageEnd, + StructBegin, + StructEnd, + FieldBegin, + FieldEnd, + BoolValue, + ByteValue, + DoubleValue, + Int16Value, + Int32Value, + Int64Value, + StringValue, + ListBegin, + ListEnd, + SetBegin, + SetEnd, + MapBegin, + MapEnd, + ContinueDecode +}; + class DecoderEventHandler { public: virtual ~DecoderEventHandler() = default; diff --git a/source/extensions/filters/network/thrift_proxy/filter_utils.cc b/source/extensions/filters/network/thrift_proxy/filter_utils.cc new file mode 100644 index 0000000000000..36ba508dcd8b8 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filter_utils.cc @@ -0,0 +1,209 @@ +#include "source/extensions/filters/network/thrift_proxy/filter_utils.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +class DelegateDecoderFilter final : public DecoderFilter { +public: + DelegateDecoderFilter(BidirectionalFilterSharedPtr filter) : parent_(filter){}; + // ThriftBaseFilter + void onDestroy() override { throw EnvoyException(fmt::format("should not be called")); } + + void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) override { + return parent_->setDecoderFilterCallbacks(callbacks); + }; + + // Thrift Decoder State Machine + ThriftProxy::FilterStatus + transportBegin(ThriftProxy::MessageMetadataSharedPtr metadata) override { + return parent_->decodeTransportBegin(metadata); + } + + ThriftProxy::FilterStatus transportEnd() override { return parent_->decodeTransportEnd(); } + + bool passthroughSupported() const override { return parent_->decodePassthroughSupported(); } + + ThriftProxy::FilterStatus passthroughData(Buffer::Instance& data) override { + return parent_->decodePassthroughData(data); + } + + ThriftProxy::FilterStatus messageBegin(ThriftProxy::MessageMetadataSharedPtr metadata) override { + return parent_->decodeMessageBegin(metadata); + } + + ThriftProxy::FilterStatus messageEnd() override { return parent_->decodeMessageEnd(); } + + ThriftProxy::FilterStatus structBegin(absl::string_view name) override { + return parent_->decodeStructBegin(name); + } + + ThriftProxy::FilterStatus structEnd() override { return parent_->decodeStructEnd(); } + + ThriftProxy::FilterStatus fieldBegin(absl::string_view name, ThriftProxy::FieldType& field_type, + int16_t& field_id) override { + return parent_->decodeFieldBegin(name, field_type, field_id); + } + + ThriftProxy::FilterStatus fieldEnd() override { return parent_->decodeFieldEnd(); } + + ThriftProxy::FilterStatus boolValue(bool& value) override { + return parent_->decodeBoolValue(value); + } + + ThriftProxy::FilterStatus byteValue(uint8_t& value) override { + return parent_->decodeByteValue(value); + } + + ThriftProxy::FilterStatus int16Value(int16_t& value) override { + return parent_->decodeInt16Value(value); + } + + ThriftProxy::FilterStatus int32Value(int32_t& value) override { + return parent_->decodeInt32Value(value); + } + + ThriftProxy::FilterStatus int64Value(int64_t& value) override { + return parent_->decodeInt64Value(value); + } + + ThriftProxy::FilterStatus doubleValue(double& value) override { + return parent_->decodeDoubleValue(value); + } + + ThriftProxy::FilterStatus stringValue(absl::string_view value) override { + return parent_->decodeStringValue(value); + } + + ThriftProxy::FilterStatus mapBegin(ThriftProxy::FieldType& key_type, + ThriftProxy::FieldType& value_type, uint32_t& size) override { + return parent_->decodeMapBegin(key_type, value_type, size); + } + + ThriftProxy::FilterStatus mapEnd() override { return parent_->decodeMapEnd(); } + + ThriftProxy::FilterStatus listBegin(ThriftProxy::FieldType& elem_type, uint32_t& size) override { + return parent_->decodeListBegin(elem_type, size); + } + + ThriftProxy::FilterStatus listEnd() override { return parent_->decodeListEnd(); } + + ThriftProxy::FilterStatus setBegin(ThriftProxy::FieldType& elem_type, uint32_t& size) override { + return parent_->decodeSetBegin(elem_type, size); + } + + ThriftProxy::FilterStatus setEnd() override { return parent_->decodeSetEnd(); } + +private: + BidirectionalFilterSharedPtr parent_; +}; + +using DelegateDecoderFilterSharedPtr = std::shared_ptr; + +class DelegateEncoderFilter final : public EncoderFilter { +public: + DelegateEncoderFilter(BidirectionalFilterSharedPtr filter) : parent_(filter){}; + // ThriftBaseFilter + void onDestroy() override { throw EnvoyException(fmt::format("should not be called")); } + + void setEncoderFilterCallbacks(EncoderFilterCallbacks& callbacks) override { + return parent_->setEncoderFilterCallbacks(callbacks); + }; + + // Thrift Encoder State Machine + ThriftProxy::FilterStatus + transportBegin(ThriftProxy::MessageMetadataSharedPtr metadata) override { + return parent_->encodeTransportBegin(metadata); + } + + ThriftProxy::FilterStatus transportEnd() override { return parent_->encodeTransportEnd(); } + + bool passthroughSupported() const override { return parent_->encodePassthroughSupported(); } + + ThriftProxy::FilterStatus passthroughData(Buffer::Instance& data) override { + return parent_->encodePassthroughData(data); + } + + ThriftProxy::FilterStatus messageBegin(ThriftProxy::MessageMetadataSharedPtr metadata) override { + return parent_->encodeMessageBegin(metadata); + } + + ThriftProxy::FilterStatus messageEnd() override { return parent_->encodeMessageEnd(); } + + ThriftProxy::FilterStatus structBegin(absl::string_view name) override { + return parent_->encodeStructBegin(name); + } + + ThriftProxy::FilterStatus structEnd() override { return parent_->encodeStructEnd(); } + + ThriftProxy::FilterStatus fieldBegin(absl::string_view name, ThriftProxy::FieldType& field_type, + int16_t& field_id) override { + return parent_->encodeFieldBegin(name, field_type, field_id); + } + + ThriftProxy::FilterStatus fieldEnd() override { return parent_->encodeFieldEnd(); } + + ThriftProxy::FilterStatus boolValue(bool& value) override { + return parent_->encodeBoolValue(value); + } + + ThriftProxy::FilterStatus byteValue(uint8_t& value) override { + return parent_->encodeByteValue(value); + } + + ThriftProxy::FilterStatus int16Value(int16_t& value) override { + return parent_->encodeInt16Value(value); + } + + ThriftProxy::FilterStatus int32Value(int32_t& value) override { + return parent_->encodeInt32Value(value); + } + + ThriftProxy::FilterStatus int64Value(int64_t& value) override { + return parent_->encodeInt64Value(value); + } + + ThriftProxy::FilterStatus doubleValue(double& value) override { + return parent_->encodeDoubleValue(value); + } + + ThriftProxy::FilterStatus stringValue(absl::string_view value) override { + return parent_->encodeStringValue(value); + } + + ThriftProxy::FilterStatus mapBegin(ThriftProxy::FieldType& key_type, + ThriftProxy::FieldType& value_type, uint32_t& size) override { + return parent_->encodeMapBegin(key_type, value_type, size); + } + + ThriftProxy::FilterStatus mapEnd() override { return parent_->encodeMapEnd(); } + + ThriftProxy::FilterStatus listBegin(ThriftProxy::FieldType& elem_type, uint32_t& size) override { + return parent_->encodeListBegin(elem_type, size); + } + + ThriftProxy::FilterStatus listEnd() override { return parent_->encodeListEnd(); } + + ThriftProxy::FilterStatus setBegin(ThriftProxy::FieldType& elem_type, uint32_t& size) override { + return parent_->encodeSetBegin(elem_type, size); + } + + ThriftProxy::FilterStatus setEnd() override { return parent_->encodeSetEnd(); } + +private: + BidirectionalFilterSharedPtr parent_; +}; + +using DelegateEncoderFilterSharedPtr = std::shared_ptr; + +BidirectionalFilterWrapper::BidirectionalFilterWrapper(BidirectionalFilterSharedPtr filter) + : decoder_filter_(std::make_shared(filter)), + encoder_filter_(std::make_shared(filter)), parent_(filter) {} + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filter_utils.h b/source/extensions/filters/network/thrift_proxy/filter_utils.h new file mode 100644 index 0000000000000..b4e0c315c3366 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/filter_utils.h @@ -0,0 +1,33 @@ +#pragma once + +#include "source/extensions/filters/network/thrift_proxy/filters/filter.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace ThriftFilters { + +class BidirectionalFilterWrapper final : public FilterBase { +public: + BidirectionalFilterWrapper(BidirectionalFilterSharedPtr filter); + + // ThriftBaseFilter + void onDestroy() override { parent_->onDestroy(); } + + DecoderFilterSharedPtr decoder_filter_; + EncoderFilterSharedPtr encoder_filter_; + +private: + BidirectionalFilterSharedPtr parent_; +}; + +using BidirectionalFilterWrapperSharedPtr = std::shared_ptr; + +} // namespace ThriftFilters +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/filters/filter.h b/source/extensions/filters/network/thrift_proxy/filters/filter.h index 13f509bdd6795..a4a8d57be6763 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/filter.h @@ -27,11 +27,11 @@ enum class ResponseStatus { }; /** - * Decoder filter callbacks add additional callbacks. + * Common interface for FilterDecoderCallbacks and FilterEncoderCallbacks. */ -class DecoderFilterCallbacks { +class FilterCallbacks { public: - virtual ~DecoderFilterCallbacks() = default; + virtual ~FilterCallbacks() = default; /** * @return uint64_t the ID of the originating stream for logging purposes. @@ -48,15 +48,6 @@ class DecoderFilterCallbacks { */ virtual Event::Dispatcher& dispatcher() PURE; - /** - * Continue iterating through the filter chain with buffered data. This routine can only be - * called if the filter has previously returned StopIteration from one of the DecoderFilter - * methods. The connection manager will callbacks to the next filter in the chain. Further note - * that if the request is not complete, the calling filter may receive further callbacks and must - * return an appropriate status code depending on what the filter needs to do. - */ - virtual void continueDecoding() PURE; - /** * @return RouteConstSharedPtr the route for the current request. */ @@ -73,11 +64,38 @@ class DecoderFilterCallbacks { virtual ProtocolType downstreamProtocolType() const PURE; /** - * Create a locally generated response using the provided response object. - * @param response DirectResponse the response to send to the downstream client - * @param end_stream if true, the downstream connection should be closed after this response + * Reset the downstream connection. */ - virtual void sendLocalReply(const ThriftProxy::DirectResponse& response, bool end_stream) PURE; + virtual void resetDownstreamConnection() PURE; + + /** + * @return StreamInfo for logging purposes. + */ + virtual StreamInfo::StreamInfo& streamInfo() PURE; + + /** + * @return Response decoder metadata created by the connection manager. + */ + virtual MessageMetadataSharedPtr responseMetadata() PURE; + + /** + * @return Signal indicating whether or not the response decoder encountered a successful/void + * reply. + */ + virtual bool responseSuccess() PURE; + + /** + * Called when upstream connection gets reset. + */ + virtual void onReset() PURE; +}; + +/** + * Decoder filter callbacks add additional callbacks. + */ +class DecoderFilterCallbacks : public virtual FilterCallbacks { +public: + ~DecoderFilterCallbacks() override = default; /** * Indicates the start of an upstream response. May only be called once. @@ -95,39 +113,41 @@ class DecoderFilterCallbacks { virtual ResponseStatus upstreamData(Buffer::Instance& data) PURE; /** - * Reset the downstream connection. - */ - virtual void resetDownstreamConnection() PURE; - - /** - * @return StreamInfo for logging purposes. + * Create a locally generated response using the provided response object. + * @param response DirectResponse the response to send to the downstream client + * @param end_stream if true, the downstream connection should be closed after this response */ - virtual StreamInfo::StreamInfo& streamInfo() PURE; + virtual void sendLocalReply(const ThriftProxy::DirectResponse& response, bool end_stream) PURE; /** - * @return Response decoder metadata created by the connection manager. + * Continue iterating through the filter chain with buffered data. This routine can only be + * called if the filter has previously returned StopIteration from one of the DecoderFilter + * methods. The connection manager will callbacks to the next filter in the chain. Further note + * that if the request is not complete, the calling filter may receive further callbacks and must + * return an appropriate status code depending on what the filter needs to do. */ - virtual MessageMetadataSharedPtr responseMetadata() PURE; + virtual void continueDecoding() PURE; +}; - /** - * @return Signal indicating whether or not the response decoder encountered a successful/void - * reply. - */ - virtual bool responseSuccess() PURE; +/** + * Encoder filter callbacks add additional callbacks. + */ +class EncoderFilterCallbacks : public virtual FilterCallbacks { +public: + ~EncoderFilterCallbacks() override = default; /** - * Called when upstream connection gets reset. + * Currently throw given we don't support StopIteration for EncoderFilter yet. */ - virtual void onReset() PURE; + virtual void continueEncoding() PURE; }; /** - * Decoder filter interface. + * Common interface for Thrift filters. */ -class DecoderFilter : public virtual DecoderEventHandler { +class FilterBase { public: - ~DecoderFilter() override = default; - + virtual ~FilterBase() = default; /** * This routine is called prior to a filter being destroyed. This may happen after normal stream * finish (both downstream and upstream) or due to reset. Every filter is responsible for making @@ -138,6 +158,16 @@ class DecoderFilter : public virtual DecoderEventHandler { * onDestroy() invoked. */ virtual void onDestroy() PURE; +}; + +using FilterBaseSharedPtr = std::shared_ptr; + +/** + * Decoder filter interface. + */ +class DecoderFilter : public FilterBase, public virtual DecoderEventHandler { +public: + ~DecoderFilter() override = default; /** * Called by the connection manager once to initialize the filter decoder callbacks that the @@ -154,6 +184,101 @@ class DecoderFilter : public virtual DecoderEventHandler { using DecoderFilterSharedPtr = std::shared_ptr; +/** + * Encoder filter interface. + * + * Currently the EncoderFilter and BidirectionalFilter support + * a. peek the metadata and content for encoding, + * b. modify the metadata and content except string value for encoding, and + * c. what DecoderFilter supports for BidirectionalFilter. + * + * Do not support + * a. pass through data separately for encode and decode, e.g., pass through data for request but + * not pass through data for response, and + * b. return StopIteration for decoder_events in for encoder filter and encode* events for + * bidirectional filter. The filters trying to return StopIteration will reset the connection. + */ +class EncoderFilter : public FilterBase, public virtual DecoderEventHandler { +public: + ~EncoderFilter() override = default; + + /** + * Called by the connection manager once to initialize the filter encoder callbacks that the + * filter should use. Callbacks will not be invoked by the filter after onDestroy() is called. + */ + virtual void setEncoderFilterCallbacks(EncoderFilterCallbacks& callbacks) PURE; + + /** + * @return True if payload passthrough is supported. Called by the connection manager once after + * messageBegin. + */ + virtual bool passthroughSupported() const PURE; +}; + +using EncoderFilterSharedPtr = std::shared_ptr; + +/** + * Bidirectional filter interface. @see EncoderFilter for limitation. + */ +class BidirectionalFilter : public FilterBase { +public: + ~BidirectionalFilter() override = default; + virtual void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) PURE; + virtual void setEncoderFilterCallbacks(EncoderFilterCallbacks& callbacks) PURE; + virtual bool decodePassthroughSupported() const PURE; + virtual bool encodePassthroughSupported() const PURE; + virtual FilterStatus decodeTransportBegin(MessageMetadataSharedPtr metadata) PURE; + virtual FilterStatus encodeTransportBegin(MessageMetadataSharedPtr metadata) PURE; + virtual FilterStatus decodeTransportEnd() PURE; + virtual FilterStatus encodeTransportEnd() PURE; + virtual FilterStatus decodePassthroughData(Buffer::Instance& data) PURE; + virtual FilterStatus encodePassthroughData(Buffer::Instance& data) PURE; + virtual FilterStatus decodeMessageBegin(MessageMetadataSharedPtr metadata) PURE; + virtual FilterStatus encodeMessageBegin(MessageMetadataSharedPtr metadata) PURE; + virtual FilterStatus decodeMessageEnd() PURE; + virtual FilterStatus encodeMessageEnd() PURE; + virtual FilterStatus decodeStructBegin(absl::string_view name) PURE; + virtual FilterStatus encodeStructBegin(absl::string_view name) PURE; + virtual FilterStatus decodeStructEnd() PURE; + virtual FilterStatus encodeStructEnd() PURE; + virtual FilterStatus decodeFieldBegin(absl::string_view name, FieldType& field_type, + int16_t& field_id) PURE; + virtual FilterStatus encodeFieldBegin(absl::string_view name, FieldType& field_type, + int16_t& field_id) PURE; + virtual FilterStatus decodeFieldEnd() PURE; + virtual FilterStatus encodeFieldEnd() PURE; + virtual FilterStatus decodeBoolValue(bool& value) PURE; + virtual FilterStatus encodeBoolValue(bool& value) PURE; + virtual FilterStatus decodeByteValue(uint8_t& value) PURE; + virtual FilterStatus encodeByteValue(uint8_t& value) PURE; + virtual FilterStatus decodeInt16Value(int16_t& value) PURE; + virtual FilterStatus encodeInt16Value(int16_t& value) PURE; + virtual FilterStatus decodeInt32Value(int32_t& value) PURE; + virtual FilterStatus encodeInt32Value(int32_t& value) PURE; + virtual FilterStatus decodeInt64Value(int64_t& value) PURE; + virtual FilterStatus encodeInt64Value(int64_t& value) PURE; + virtual FilterStatus decodeDoubleValue(double& value) PURE; + virtual FilterStatus encodeDoubleValue(double& value) PURE; + virtual FilterStatus decodeStringValue(absl::string_view value) PURE; + virtual FilterStatus encodeStringValue(absl::string_view value) PURE; + virtual FilterStatus decodeMapBegin(FieldType& key_type, FieldType& value_type, + uint32_t& size) PURE; + virtual FilterStatus encodeMapBegin(FieldType& key_type, FieldType& value_type, + uint32_t& size) PURE; + virtual FilterStatus decodeMapEnd() PURE; + virtual FilterStatus encodeMapEnd() PURE; + virtual FilterStatus decodeListBegin(FieldType& elem_type, uint32_t& size) PURE; + virtual FilterStatus encodeListBegin(FieldType& elem_type, uint32_t& size) PURE; + virtual FilterStatus decodeListEnd() PURE; + virtual FilterStatus encodeListEnd() PURE; + virtual FilterStatus decodeSetBegin(FieldType& elem_type, uint32_t& size) PURE; + virtual FilterStatus encodeSetBegin(FieldType& elem_type, uint32_t& size) PURE; + virtual FilterStatus decodeSetEnd() PURE; + virtual FilterStatus encodeSetEnd() PURE; +}; + +using BidirectionalFilterSharedPtr = std::shared_ptr; + /** * These callbacks are provided by the connection manager to the factory so that the factory can * build the filter chain in an application specific way. @@ -167,6 +292,18 @@ class FilterChainFactoryCallbacks { * @param filter supplies the filter to add. */ virtual void addDecoderFilter(DecoderFilterSharedPtr filter) PURE; + + /** + * Add an encoder filter that is used when writing connection data. + * @param filter supplies the filter to add. + */ + virtual void addEncoderFilter(EncoderFilterSharedPtr filter) PURE; + + /** + * Add a bidirectional filter that is used when reading and writing connection data. + * @param filter supplies the filter to add. + */ + virtual void addBidirectionalFilter(BidirectionalFilterSharedPtr filter) PURE; }; /** diff --git a/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h b/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h index 5858ff46e6c16..f19344664cd72 100644 --- a/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h +++ b/source/extensions/filters/network/thrift_proxy/filters/pass_through_filter.h @@ -11,14 +11,15 @@ namespace ThriftProxy { namespace ThriftFilters { /** - * Pass through Thrift decoder filter. Continue at each decoding state within the series of - * transitions. + * Pass through Thrift decoder/encoder/bidirectional filter. Continue at each state within the + * series of transitions, and pass through the decoded/encoded data. */ class PassThroughDecoderFilter : public DecoderFilter { public: - // ThriftDecoderFilter + // Thrift FilterBase void onDestroy() override {} + // Thrift DecoderFilter void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) override { decoder_callbacks_ = &callbacks; }; @@ -106,6 +107,248 @@ class PassThroughDecoderFilter : public DecoderFilter { DecoderFilterCallbacks* decoder_callbacks_{}; }; +class PassThroughEncoderFilter : public EncoderFilter { +public: + // Thrift FilterBase + void onDestroy() override {} + + // Thrift EncoderFilter + void setEncoderFilterCallbacks(EncoderFilterCallbacks& callbacks) override { + encoder_callbacks_ = &callbacks; + }; + + // Thrift Encoder State Machine + ThriftProxy::FilterStatus transportBegin(ThriftProxy::MessageMetadataSharedPtr) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus transportEnd() override { return ThriftProxy::FilterStatus::Continue; } + + bool passthroughSupported() const override { return true; } + + ThriftProxy::FilterStatus passthroughData(Buffer::Instance&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus messageBegin(ThriftProxy::MessageMetadataSharedPtr) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus messageEnd() override { return ThriftProxy::FilterStatus::Continue; } + + ThriftProxy::FilterStatus structBegin(absl::string_view) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus structEnd() override { return ThriftProxy::FilterStatus::Continue; } + + ThriftProxy::FilterStatus fieldBegin(absl::string_view, ThriftProxy::FieldType&, + int16_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus fieldEnd() override { return ThriftProxy::FilterStatus::Continue; } + + ThriftProxy::FilterStatus boolValue(bool&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus byteValue(uint8_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus int16Value(int16_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus int32Value(int32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus int64Value(int64_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus doubleValue(double&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus stringValue(absl::string_view) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus mapBegin(ThriftProxy::FieldType&, ThriftProxy::FieldType&, + uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus mapEnd() override { return ThriftProxy::FilterStatus::Continue; } + + ThriftProxy::FilterStatus listBegin(ThriftProxy::FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus listEnd() override { return ThriftProxy::FilterStatus::Continue; } + + ThriftProxy::FilterStatus setBegin(ThriftProxy::FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + ThriftProxy::FilterStatus setEnd() override { return ThriftProxy::FilterStatus::Continue; } + +protected: + EncoderFilterCallbacks* encoder_callbacks_{}; +}; + +class PassThroughBidirectionalFilter : public BidirectionalFilter { +public: + // ThriftFilterBase + void onDestroy() override {} + + // Thrift DecoderFilter + void setDecoderFilterCallbacks(DecoderFilterCallbacks& callbacks) override { + decoder_callbacks_ = &callbacks; + } + + // Thrift EncoderFilter + void setEncoderFilterCallbacks(EncoderFilterCallbacks& callbacks) override { + encoder_callbacks_ = &callbacks; + } + + // Thrift Decoder/Encoder State Machine + bool decodePassthroughSupported() const override { return true; } + + bool encodePassthroughSupported() const override { return true; } + + FilterStatus decodeTransportBegin(MessageMetadataSharedPtr) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeTransportBegin(MessageMetadataSharedPtr) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeTransportEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeTransportEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodePassthroughData(Buffer::Instance&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodePassthroughData(Buffer::Instance&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeMessageBegin(MessageMetadataSharedPtr) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeMessageBegin(MessageMetadataSharedPtr) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeMessageEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeMessageEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeStructBegin(absl::string_view) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeStructBegin(absl::string_view) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeStructEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeStructEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeFieldBegin(absl::string_view, FieldType&, int16_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeFieldBegin(absl::string_view, FieldType&, int16_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeFieldEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeFieldEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeBoolValue(bool&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeBoolValue(bool&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeByteValue(uint8_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeByteValue(uint8_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeInt16Value(int16_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeInt16Value(int16_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeInt32Value(int32_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeInt32Value(int32_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeInt64Value(int64_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeInt64Value(int64_t&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeDoubleValue(double&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeDoubleValue(double&) override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeStringValue(absl::string_view) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeStringValue(absl::string_view) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeMapBegin(FieldType&, FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeMapBegin(FieldType&, FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeMapEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeMapEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeListBegin(FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeListBegin(FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeListEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeListEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus decodeSetBegin(FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus encodeSetBegin(FieldType&, uint32_t&) override { + return ThriftProxy::FilterStatus::Continue; + } + + FilterStatus decodeSetEnd() override { return ThriftProxy::FilterStatus::Continue; } + + FilterStatus encodeSetEnd() override { return ThriftProxy::FilterStatus::Continue; } + +protected: + EncoderFilterCallbacks* encoder_callbacks_{}; + DecoderFilterCallbacks* decoder_callbacks_{}; +}; + } // namespace ThriftFilters } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/source/extensions/filters/network/thrift_proxy/protocol_converter.h b/source/extensions/filters/network/thrift_proxy/protocol_converter.h index 62de1a1f9fa75..a24d8d7f2ecf7 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_converter.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_converter.h @@ -11,7 +11,7 @@ namespace NetworkFilters { namespace ThriftProxy { /** - * ProtocolConverter is an abstract class that implements protocol-related methods on + * ProtocolConverter is an class that implements protocol-related methods on * DecoderEventHandler in terms of converting the decoded messages into a different protocol. */ class ProtocolConverter : public virtual DecoderEventHandler { @@ -30,6 +30,10 @@ class ProtocolConverter : public virtual DecoderEventHandler { return FilterStatus::Continue; } + FilterStatus transportBegin(MessageMetadataSharedPtr) override { return FilterStatus::Continue; } + + FilterStatus transportEnd() override { return FilterStatus::Continue; } + FilterStatus messageBegin(MessageMetadataSharedPtr metadata) override { proto_->writeMessageBegin(*buffer_, *metadata); return FilterStatus::Continue; @@ -132,6 +136,7 @@ class ProtocolConverter : public virtual DecoderEventHandler { Buffer::Instance* buffer_{}; }; +using ProtocolConverterSharedPtr = std::shared_ptr; } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 0c9215c5a3289..49275e657ea03 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -41,17 +41,26 @@ class TestConfigImpl : public ConfigImpl { TestConfigImpl(envoy::extensions::filters::network::thrift_proxy::v3::ThriftProxy proto_config, Server::Configuration::MockFactoryContext& context, Router::RouteConfigProviderManager& route_config_provider_manager, - ThriftFilters::DecoderFilterSharedPtr decoder_filter, ThriftFilterStats& stats) + ThriftFilters::DecoderFilterSharedPtr decoder_filter, + ThriftFilters::EncoderFilterSharedPtr encoder_filter, + ThriftFilters::BidirectionalFilterSharedPtr bidirectional_filter, + ThriftFilterStats& stats) : ConfigImpl(proto_config, context, route_config_provider_manager), - decoder_filter_(decoder_filter), stats_(stats) {} + decoder_filter_(decoder_filter), encoder_filter_(encoder_filter), + bidirectional_filter_(bidirectional_filter), stats_(stats) {} // ConfigImpl ThriftFilterStats& stats() override { return stats_; } void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override { - if (custom_filter_) { - callbacks.addDecoderFilter(custom_filter_); + if (custom_decoder_filter_) { + callbacks.addDecoderFilter(custom_decoder_filter_); } callbacks.addDecoderFilter(decoder_filter_); + if (custom_encoder_filter_) { + callbacks.addEncoderFilter(custom_encoder_filter_); + } + callbacks.addEncoderFilter(encoder_filter_); + callbacks.addBidirectionalFilter(bidirectional_filter_); } TransportPtr createTransport() override { if (transport_) { @@ -66,8 +75,11 @@ class TestConfigImpl : public ConfigImpl { return ConfigImpl::createProtocol(); } - ThriftFilters::DecoderFilterSharedPtr custom_filter_; + ThriftFilters::DecoderFilterSharedPtr custom_decoder_filter_; ThriftFilters::DecoderFilterSharedPtr decoder_filter_; + ThriftFilters::EncoderFilterSharedPtr custom_encoder_filter_; + ThriftFilters::EncoderFilterSharedPtr encoder_filter_; + ThriftFilters::BidirectionalFilterSharedPtr bidirectional_filter_; ThriftFilterStats& stats_; MockTransport* transport_{}; MockProtocol* protocol_{}; @@ -83,6 +95,13 @@ class ThriftConnectionManagerTest : public testing::Test { filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); } + void initializeFilterWithCustomFilters() { + auto* decoder_filter = new NiceMock(); + custom_decoder_filter_.reset(decoder_filter); + auto* encoder_filter = new NiceMock(); + custom_encoder_filter_.reset(encoder_filter); + initializeFilter(); + } void initializeFilter() { initializeFilter(""); } void initializeFilter(const std::string& yaml, @@ -113,17 +132,24 @@ class ThriftConnectionManagerTest : public testing::Test { context_.server_factory_context_.cluster_manager_.initializeClusters(cluster_names, {}); proto_config_ = config; + decoder_filter_ = std::make_shared>(); + encoder_filter_ = std::make_shared>(); + bidirectional_filter_ = std::make_shared>(); - config_ = std::make_unique( - proto_config_, context_, *route_config_provider_manager_, decoder_filter_, stats_); + config_ = std::make_unique(proto_config_, context_, + *route_config_provider_manager_, decoder_filter_, + encoder_filter_, bidirectional_filter_, stats_); if (custom_transport_) { config_->transport_ = custom_transport_; } if (custom_protocol_) { config_->protocol_ = custom_protocol_; } - if (custom_filter_) { - config_->custom_filter_ = custom_filter_; + if (custom_decoder_filter_) { + config_->custom_decoder_filter_ = custom_decoder_filter_; + } + if (custom_encoder_filter_) { + config_->custom_encoder_filter_ = custom_encoder_filter_; } ON_CALL(random_, random()).WillByDefault(Return(42)); @@ -337,6 +363,7 @@ class ThriftConnectionManagerTest : public testing::Test { proto->writeMessageBegin(msg, metadata); proto->writeStructBegin(msg, ""); + // successful response struct in field id 0, error (IDL exception) in field id greater than 0 proto->writeFieldBegin(msg, "", FieldType::Struct, 2); proto->writeStructBegin(msg, ""); @@ -368,6 +395,8 @@ class ThriftConnectionManagerTest : public testing::Test { initializeFilter(); writeComplexFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + checkDecoderEventsCalledToFilters(MessageType::Call, 0x0F, MessageType::Reply, 0x0F); + ThriftFilters::DecoderFilterCallbacks* callbacks{}; EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( @@ -412,8 +441,177 @@ class ThriftConnectionManagerTest : public testing::Test { EXPECT_EQ(draining ? 1U : 0U, store_.counter("test.downstream_response_drain_close").value()); } + void checkDecoderEventsCalledToFilters(MessageType req_msg_type, int32_t req_seq_id, + MessageType resp_msg_type, int32_t resp_seq_id) { + bool one = true; + uint8_t two = 2; + double three = 3.0; + int16_t four = 4; + int32_t five = 5; + int64_t six = 6; + int32_t eight = 8; + FieldType field_type_i32 = FieldType::I32; + + uint32_t one_32 = 1; + EXPECT_CALL(*decoder_filter_, messageBegin(_)) + .WillOnce( + Invoke([req_msg_type, req_seq_id](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(req_msg_type, metadata->messageType()); + EXPECT_EQ(req_seq_id, metadata->sequenceId()); + return FilterStatus::Continue; + })); + EXPECT_CALL(*decoder_filter_, messageEnd()); + // The struct name is not available at runtime. + EXPECT_CALL(*decoder_filter_, structBegin("")).Times(2); + EXPECT_CALL(*decoder_filter_, structEnd()).Times(2); + // The field name is not available at runtime. + EXPECT_CALL(*decoder_filter_, fieldBegin("", _, _)).Times(11); + EXPECT_CALL(*decoder_filter_, fieldEnd()).Times(11); + EXPECT_CALL(*decoder_filter_, boolValue(one)); + EXPECT_CALL(*decoder_filter_, byteValue(two)); + EXPECT_CALL(*decoder_filter_, doubleValue(three)); + EXPECT_CALL(*decoder_filter_, int16Value(four)); + EXPECT_CALL(*decoder_filter_, int32Value(five)); + EXPECT_CALL(*decoder_filter_, int64Value(six)); + EXPECT_CALL(*decoder_filter_, stringValue("seven")); + EXPECT_CALL(*decoder_filter_, mapBegin(field_type_i32, field_type_i32, one_32)); + EXPECT_CALL(*decoder_filter_, int32Value(eight)).Times(4); + EXPECT_CALL(*decoder_filter_, mapEnd()); + EXPECT_CALL(*decoder_filter_, listBegin(field_type_i32, one_32)); + EXPECT_CALL(*decoder_filter_, listEnd()); + EXPECT_CALL(*decoder_filter_, setBegin(field_type_i32, one_32)); + EXPECT_CALL(*decoder_filter_, setEnd()); + + EXPECT_CALL(*encoder_filter_, transportBegin(_)); + EXPECT_CALL(*encoder_filter_, transportEnd()); + EXPECT_CALL(*encoder_filter_, messageBegin(_)) + .WillOnce( + Invoke([resp_msg_type, resp_seq_id](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(resp_msg_type, metadata->messageType()); + EXPECT_EQ(resp_seq_id, metadata->sequenceId()); + return FilterStatus::Continue; + })); + EXPECT_CALL(*encoder_filter_, messageEnd()); + EXPECT_CALL(*encoder_filter_, structBegin("")).Times(2); + EXPECT_CALL(*encoder_filter_, structEnd()).Times(2); + EXPECT_CALL(*encoder_filter_, fieldBegin("", _, _)).Times(11); + EXPECT_CALL(*encoder_filter_, fieldEnd()).Times(11); + EXPECT_CALL(*encoder_filter_, boolValue(one)); + EXPECT_CALL(*encoder_filter_, byteValue(two)); + EXPECT_CALL(*encoder_filter_, doubleValue(three)); + EXPECT_CALL(*encoder_filter_, int16Value(four)); + EXPECT_CALL(*encoder_filter_, int32Value(five)); + EXPECT_CALL(*encoder_filter_, int64Value(six)); + EXPECT_CALL(*encoder_filter_, stringValue("seven")); + EXPECT_CALL(*encoder_filter_, mapBegin(field_type_i32, field_type_i32, one_32)); + EXPECT_CALL(*encoder_filter_, int32Value(eight)).Times(4); + EXPECT_CALL(*encoder_filter_, mapEnd()); + EXPECT_CALL(*encoder_filter_, listBegin(field_type_i32, one_32)); + EXPECT_CALL(*encoder_filter_, listEnd()); + EXPECT_CALL(*encoder_filter_, setBegin(field_type_i32, one_32)); + EXPECT_CALL(*encoder_filter_, setEnd()); + + EXPECT_CALL(*bidirectional_filter_, decodeMessageBegin(_)) + .WillOnce( + Invoke([req_msg_type, req_seq_id](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(req_msg_type, metadata->messageType()); + EXPECT_EQ(req_seq_id, metadata->sequenceId()); + return FilterStatus::Continue; + })); + EXPECT_CALL(*bidirectional_filter_, decodeMessageEnd()); + EXPECT_CALL(*bidirectional_filter_, decodeStructBegin("")).Times(2); + EXPECT_CALL(*bidirectional_filter_, decodeStructEnd()).Times(2); + EXPECT_CALL(*bidirectional_filter_, decodeFieldBegin("", _, _)).Times(11); + EXPECT_CALL(*bidirectional_filter_, decodeFieldEnd()).Times(11); + EXPECT_CALL(*bidirectional_filter_, decodeBoolValue(one)); + EXPECT_CALL(*bidirectional_filter_, decodeByteValue(two)); + EXPECT_CALL(*bidirectional_filter_, decodeDoubleValue(three)); + EXPECT_CALL(*bidirectional_filter_, decodeInt16Value(four)); + EXPECT_CALL(*bidirectional_filter_, decodeInt32Value(five)); + EXPECT_CALL(*bidirectional_filter_, decodeInt64Value(six)); + EXPECT_CALL(*bidirectional_filter_, decodeStringValue("seven")); + EXPECT_CALL(*bidirectional_filter_, decodeMapBegin(field_type_i32, field_type_i32, one_32)); + EXPECT_CALL(*bidirectional_filter_, decodeInt32Value(eight)).Times(4); + EXPECT_CALL(*bidirectional_filter_, decodeMapEnd()); + EXPECT_CALL(*bidirectional_filter_, decodeListBegin(field_type_i32, one_32)); + EXPECT_CALL(*bidirectional_filter_, decodeListEnd()); + EXPECT_CALL(*bidirectional_filter_, decodeSetBegin(field_type_i32, one_32)); + EXPECT_CALL(*bidirectional_filter_, decodeSetEnd()); + + EXPECT_CALL(*bidirectional_filter_, encodeTransportBegin(_)); + EXPECT_CALL(*bidirectional_filter_, encodeTransportEnd()); + EXPECT_CALL(*bidirectional_filter_, encodeMessageBegin(_)) + .WillOnce( + Invoke([resp_msg_type, resp_seq_id](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(resp_msg_type, metadata->messageType()); + EXPECT_EQ(resp_seq_id, metadata->sequenceId()); + return FilterStatus::Continue; + })); + EXPECT_CALL(*bidirectional_filter_, encodeMessageEnd()); + EXPECT_CALL(*bidirectional_filter_, encodeStructBegin("")).Times(2); + EXPECT_CALL(*bidirectional_filter_, encodeStructEnd()).Times(2); + EXPECT_CALL(*bidirectional_filter_, encodeFieldBegin("", _, _)).Times(11); + EXPECT_CALL(*bidirectional_filter_, encodeFieldEnd()).Times(11); + EXPECT_CALL(*bidirectional_filter_, encodeBoolValue(one)); + EXPECT_CALL(*bidirectional_filter_, encodeByteValue(two)); + EXPECT_CALL(*bidirectional_filter_, encodeDoubleValue(three)); + EXPECT_CALL(*bidirectional_filter_, encodeInt16Value(four)); + EXPECT_CALL(*bidirectional_filter_, encodeInt32Value(five)); + EXPECT_CALL(*bidirectional_filter_, encodeInt64Value(six)); + EXPECT_CALL(*bidirectional_filter_, encodeStringValue("seven")); + EXPECT_CALL(*bidirectional_filter_, encodeMapBegin(field_type_i32, field_type_i32, one_32)); + EXPECT_CALL(*bidirectional_filter_, encodeInt32Value(eight)).Times(4); + EXPECT_CALL(*bidirectional_filter_, encodeMapEnd()); + EXPECT_CALL(*bidirectional_filter_, encodeListBegin(field_type_i32, one_32)); + EXPECT_CALL(*bidirectional_filter_, encodeListEnd()); + EXPECT_CALL(*bidirectional_filter_, encodeSetBegin(field_type_i32, one_32)); + EXPECT_CALL(*bidirectional_filter_, encodeSetEnd()); + } + + void passthroughSupportedSetup(bool expected_decode_passthrough_data_called = true, + bool expected_encode_passthrough_data_called = true) { + EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*encoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*bidirectional_filter_, decodePassthroughSupported()).WillRepeatedly(Return(true)); + EXPECT_CALL(*bidirectional_filter_, encodePassthroughSupported()).WillRepeatedly(Return(true)); + + if (expected_decode_passthrough_data_called) { + EXPECT_CALL(*decoder_filter_, passthroughData(_)); + EXPECT_CALL(*bidirectional_filter_, decodePassthroughData(_)); + } else { + EXPECT_CALL(*decoder_filter_, passthroughData(_)).Times(0); + EXPECT_CALL(*bidirectional_filter_, decodePassthroughData(_)).Times(0); + } + + if (expected_encode_passthrough_data_called) { + EXPECT_CALL(*encoder_filter_, passthroughData(_)); + EXPECT_CALL(*bidirectional_filter_, encodePassthroughData(_)); + } else { + EXPECT_CALL(*encoder_filter_, passthroughData(_)).Times(0); + EXPECT_CALL(*bidirectional_filter_, encodePassthroughData(_)).Times(0); + } + } + NiceMock context_; std::shared_ptr decoder_filter_; + std::shared_ptr encoder_filter_; + std::shared_ptr bidirectional_filter_; Stats::TestUtil::TestStore store_; ThriftFilterStats stats_; envoy::extensions::filters::network::thrift_proxy::v3::ThriftProxy proto_config_; @@ -429,7 +627,8 @@ class ThriftConnectionManagerTest : public testing::Test { std::unique_ptr filter_; MockTransport* custom_transport_{}; MockProtocol* custom_protocol_{}; - ThriftFilters::DecoderFilterSharedPtr custom_filter_; + std::shared_ptr custom_decoder_filter_; + std::shared_ptr custom_encoder_filter_; }; TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCall) { @@ -1132,7 +1331,7 @@ TEST_F(ThriftConnectionManagerTest, ResetDownstreamConnection) { // Test the base case where there is no limit on the number of requests. TEST_F(ThriftConnectionManagerTest, RequestWithNoMaxRequestsLimit) { - initializeFilter(""); + initializeFilter(); EXPECT_EQ(0, config_->maxRequestsPerConnection()); EXPECT_EQ(50, sendRequests(50)); @@ -1369,41 +1568,75 @@ TEST_F(ThriftConnectionManagerTest, DownstreamProtocolUpgrade) { // Tests multiple filters are invoked in the correct order. TEST_F(ThriftConnectionManagerTest, OnDataHandlesThriftCallWithMultipleFilters) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); - initializeFilter(); + initializeFilterWithCustomFilters(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + InSequence s; - EXPECT_CALL(*filter, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); EXPECT_CALL(*decoder_filter_, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*filter, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*bidirectional_filter_, decodeMessageBegin(_)) + .WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_decoder_filter_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); EXPECT_CALL(*decoder_filter_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*bidirectional_filter_, decodeMessageEnd()).WillOnce(Return(FilterStatus::Continue)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); EXPECT_EQ(1U, store_.counter("test.request").value()); EXPECT_EQ(1U, store_.counter("test.request_call").value()); EXPECT_EQ(1U, stats_.request_active_.value()); + + // Reverse order for encoder filters. + EXPECT_CALL(*bidirectional_filter_, encodeMessageBegin(_)) + .WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*encoder_filter_, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_encoder_filter_, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*bidirectional_filter_, encodeMessageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*encoder_filter_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_encoder_filter_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + + writeFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x01); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(1U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); } // Tests stop iteration/resume with multiple filters. TEST_F(ThriftConnectionManagerTest, OnDataResumesWithNextFilter) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); + initializeFilterWithCustomFilters(); - initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); ThriftFilters::DecoderFilterCallbacks* callbacks{}; - EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)) + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); - // First filter stops iteration. + ThriftFilters::EncoderFilterCallbacks* encoder_callbacks{}; + EXPECT_CALL(*custom_encoder_filter_, setEncoderFilterCallbacks(_)) + .WillOnce(Invoke( + [&](ThriftFilters::EncoderFilterCallbacks& cb) -> void { encoder_callbacks = &cb; })); + + // First decoder filter stops iteration. { - EXPECT_CALL(*filter, messageBegin(_)).WillOnce(Return(FilterStatus::StopIteration)); + EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) + .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); EXPECT_EQ(0U, store_.counter("test.request").value()); EXPECT_EQ(1U, stats_.request_active_.value()); @@ -1413,8 +1646,12 @@ TEST_F(ThriftConnectionManagerTest, OnDataResumesWithNextFilter) { { InSequence s; EXPECT_CALL(*decoder_filter_, messageBegin(_)).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*filter, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*bidirectional_filter_, decodeMessageBegin(_)) + .WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_decoder_filter_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); EXPECT_CALL(*decoder_filter_, messageEnd()).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*bidirectional_filter_, decodeMessageEnd()) + .WillOnce(Return(FilterStatus::Continue)); callbacks->continueDecoding(); } @@ -1426,14 +1663,12 @@ TEST_F(ThriftConnectionManagerTest, OnDataResumesWithNextFilter) { // Tests stop iteration/resume with multiple filters when iteration is stopped during // transportEnd. TEST_F(ThriftConnectionManagerTest, OnDataResumesWithNextFilterOnTransportEnd) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); + initializeFilterWithCustomFilters(); - initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); ThriftFilters::DecoderFilterCallbacks* callbacks{}; - EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)) + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); @@ -1441,9 +1676,11 @@ TEST_F(ThriftConnectionManagerTest, OnDataResumesWithNextFilterOnTransportEnd) { // First filter stops iteration. { InSequence s; - EXPECT_CALL(*filter, transportBegin(_)).WillOnce(Return(FilterStatus::Continue)); + EXPECT_CALL(*custom_decoder_filter_, transportBegin(_)) + .WillOnce(Return(FilterStatus::Continue)); EXPECT_CALL(*decoder_filter_, transportBegin(_)).WillOnce(Return(FilterStatus::Continue)); - EXPECT_CALL(*filter, transportEnd()).WillOnce(Return(FilterStatus::StopIteration)); + EXPECT_CALL(*custom_decoder_filter_, transportEnd()) + .WillOnce(Return(FilterStatus::StopIteration)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); EXPECT_EQ(0U, store_.counter("test.request").value()); EXPECT_EQ(1U, stats_.request_active_.value()); @@ -1463,14 +1700,12 @@ TEST_F(ThriftConnectionManagerTest, OnDataResumesWithNextFilterOnTransportEnd) { // Tests multiple filters where one invokes sendLocalReply with a successful reply. TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalReply) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); + initializeFilterWithCustomFilters(); - initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); ThriftFilters::DecoderFilterCallbacks* callbacks{}; - EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)) + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); @@ -1484,7 +1719,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalReply) { })); // First filter sends local reply. - EXPECT_CALL(*filter, messageBegin(_)) + EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr) -> FilterStatus { callbacks->sendLocalReply(direct_response, false); return FilterStatus::StopIteration; @@ -1507,14 +1742,12 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalReply) { // Tests multiple filters where one invokes sendLocalReply with an error reply. TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalErrorReply) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); + initializeFilterWithCustomFilters(); - initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); ThriftFilters::DecoderFilterCallbacks* callbacks{}; - EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)) + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); @@ -1528,7 +1761,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalErrorReply) { })); // First filter sends local reply. - EXPECT_CALL(*filter, messageBegin(_)) + EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr) -> FilterStatus { callbacks->sendLocalReply(direct_response, false); return FilterStatus::StopIteration; @@ -1551,14 +1784,12 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendsLocalErrorReply) { // sendLocalReply does nothing, when the remote closed the connection. TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendLocalReplyRemoteClosedConnection) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); + initializeFilterWithCustomFilters(); - initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); ThriftFilters::DecoderFilterCallbacks* callbacks{}; - EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)) + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); @@ -1567,7 +1798,7 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendLocalReplyRemoteClosedCo EXPECT_CALL(direct_response, encode(_, _, _)).Times(0); // First filter sends local reply. - EXPECT_CALL(*filter, messageBegin(_)) + EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr) -> FilterStatus { callbacks->sendLocalReply(direct_response, false); return FilterStatus::StopIteration; @@ -1589,21 +1820,19 @@ TEST_F(ThriftConnectionManagerTest, OnDataWithFilterSendLocalReplyRemoteClosedCo // Tests a decoder filter that modifies data. TEST_F(ThriftConnectionManagerTest, DecoderFiltersModifyRequests) { - auto* filter = new NiceMock(); - custom_filter_.reset(filter); + initializeFilterWithCustomFilters(); - initializeFilter(); - writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + writeComplexFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); ThriftFilters::DecoderFilterCallbacks* callbacks{}; - EXPECT_CALL(*filter, setDecoderFilterCallbacks(_)) + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) .WillOnce( Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); Http::LowerCaseString key{"key"}; - EXPECT_CALL(*filter, transportBegin(_)) + EXPECT_CALL(*custom_decoder_filter_, transportBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { EXPECT_THAT(*metadata, HasNoHeaders()); metadata->headers().addCopy(key, "value"); @@ -1616,8 +1845,7 @@ TEST_F(ThriftConnectionManagerTest, DecoderFiltersModifyRequests) { EXPECT_EQ("value", header[0]->value().getStringView()); return FilterStatus::Continue; })); - - EXPECT_CALL(*filter, messageBegin(_)) + EXPECT_CALL(*custom_decoder_filter_, messageBegin(_)) .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { EXPECT_EQ("name", metadata->methodName()); metadata->setMethodName("alternate"); @@ -1628,13 +1856,88 @@ TEST_F(ThriftConnectionManagerTest, DecoderFiltersModifyRequests) { EXPECT_EQ("alternate", metadata->methodName()); return FilterStatus::Continue; })); - + EXPECT_CALL(*custom_decoder_filter_, boolValue(_)) + .WillOnce(Invoke([&](bool& value) -> FilterStatus { + EXPECT_EQ(true, value); + value = false; + return FilterStatus::Continue; + })); + EXPECT_CALL(*decoder_filter_, boolValue(_)).WillOnce(Invoke([&](bool& value) -> FilterStatus { + EXPECT_EQ(false, value); + return FilterStatus::Continue; + })); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); EXPECT_EQ(1U, store_.counter("test.request").value()); EXPECT_EQ(1U, store_.counter("test.request_call").value()); EXPECT_EQ(1U, stats_.request_active_.value()); } +// Tests a encoder filter that modifies data. +TEST_F(ThriftConnectionManagerTest, EncoderFiltersModifyRequests) { + initializeFilterWithCustomFilters(); + + writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*custom_decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + + Http::LowerCaseString key{"key"}; + + EXPECT_CALL(*encoder_filter_, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_THAT(*metadata, HasNoHeaders()); + metadata->headers().addCopy(key, "value"); + return FilterStatus::Continue; + })); + EXPECT_CALL(*custom_encoder_filter_, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + const auto header = metadata->headers().get(key); + EXPECT_FALSE(header.empty()); + EXPECT_EQ("value", header[0]->value().getStringView()); + return FilterStatus::Continue; + })); + EXPECT_CALL(*encoder_filter_, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_EQ("name", metadata->methodName()); + metadata->setMethodName("alternate"); + return FilterStatus::Continue; + })); + EXPECT_CALL(*custom_encoder_filter_, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> FilterStatus { + EXPECT_EQ("alternate", metadata->methodName()); + return FilterStatus::Continue; + })); + EXPECT_CALL(*encoder_filter_, boolValue(_)).WillOnce(Invoke([&](bool& value) -> FilterStatus { + EXPECT_EQ(true, value); + value = false; + return FilterStatus::Continue; + })); + EXPECT_CALL(*custom_encoder_filter_, boolValue(_)) + .WillOnce(Invoke([&](bool& value) -> FilterStatus { + EXPECT_EQ(false, value); + return FilterStatus::Continue; + })); + writeComplexFramedBinaryMessage(write_buffer_, MessageType::Reply, 0x0F); + FramedTransportImpl transport; + BinaryProtocolImpl proto; + callbacks->startUpstreamResponse(transport, proto); + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); + EXPECT_EQ(ThriftFilters::ResponseStatus::Complete, callbacks->upstreamData(write_buffer_)); + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.response").value()); + EXPECT_EQ(1U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(1U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); +} + TEST_F(ThriftConnectionManagerTest, TransportEndWhenRemoteClose) { initializeFilter(); writeComplexFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); @@ -1674,8 +1977,8 @@ payload_passthrough: true initializeFilter(yaml); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)); + // No response since the decoder filter stop the iteration. + passthroughSupportedSetup(true, false); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); EXPECT_EQ(0, buffer_.length()); @@ -1698,8 +2001,8 @@ payload_passthrough: true initializeFilter(yaml); writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)); + // No response for oneway. + passthroughSupportedSetup(true, false); EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)); EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); @@ -1724,8 +2027,7 @@ payload_passthrough: true initializeFilter(yaml); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)); + passthroughSupportedSetup(); ThriftFilters::DecoderFilterCallbacks* callbacks{}; EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) @@ -1767,8 +2069,7 @@ payload_passthrough: true initializeFilter(yaml); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)); + passthroughSupportedSetup(); ThriftFilters::DecoderFilterCallbacks* callbacks{}; EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) @@ -1810,8 +2111,7 @@ payload_passthrough: true initializeFilter(yaml); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)); + passthroughSupportedSetup(); ThriftFilters::DecoderFilterCallbacks* callbacks{}; EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) @@ -1862,8 +2162,8 @@ stat_prefix: test initializeFilter(yaml, {"cluster"}); writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)); + // No response since the decoder filter stop the iteration. + passthroughSupportedSetup(true, false); ThriftFilters::DecoderFilterCallbacks* callbacks{}; EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) @@ -1906,8 +2206,8 @@ stat_prefix: test initializeFilter(yaml, {"cluster"}); writeFramedBinaryMessage(buffer_, MessageType::Oneway, 0x0F); - EXPECT_CALL(*decoder_filter_, passthroughSupported()).WillRepeatedly(Return(true)); - EXPECT_CALL(*decoder_filter_, passthroughData(_)).Times(0); + // PassthroughData is not expected to be called. + passthroughSupportedSetup(false, false); ThriftFilters::DecoderFilterCallbacks* callbacks{}; EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) diff --git a/test/extensions/filters/network/thrift_proxy/filters/pass_through_filter_test.cc b/test/extensions/filters/network/thrift_proxy/filters/pass_through_filter_test.cc index abf5a15bc569e..3c6f13aa29092 100644 --- a/test/extensions/filters/network/thrift_proxy/filters/pass_through_filter_test.cc +++ b/test/extensions/filters/network/thrift_proxy/filters/pass_through_filter_test.cc @@ -42,6 +42,7 @@ TEST_F(ThriftPassThroughDecoderFilterTest, AllMethodsAreImplementedTrivially) { initialize(); EXPECT_EQ(&filter_callbacks_, filter_->decoderFilterCallbacks()); + EXPECT_TRUE(filter_->passthroughSupported()); EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->transportBegin(request_metadata_)); EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->messageBegin(request_metadata_)); @@ -111,6 +112,262 @@ TEST_F(ThriftPassThroughDecoderFilterTest, AllMethodsAreImplementedTrivially) { EXPECT_NO_THROW(filter_->onDestroy()); } +class ThriftPassThroughEncoderFilterTest : public testing::Test { +public: + class Filter : public PassThroughEncoderFilter { + public: + EncoderFilterCallbacks* encoderFilterCallbacks() { return encoder_callbacks_; } + }; + + void initialize() { + filter_ = std::make_unique(); + filter_->setEncoderFilterCallbacks(filter_callbacks_); + } + + std::unique_ptr filter_; + NiceMock filter_callbacks_; + ThriftProxy::MessageMetadataSharedPtr request_metadata_; +}; + +// Tests that each method returns ThriftProxy::FilterStatus::Continue. +TEST_F(ThriftPassThroughEncoderFilterTest, AllMethodsAreImplementedTrivially) { + initialize(); + + EXPECT_EQ(&filter_callbacks_, filter_->encoderFilterCallbacks()); + EXPECT_TRUE(filter_->passthroughSupported()); + + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->transportBegin(request_metadata_)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->messageBegin(request_metadata_)); + { + std::string dummy_str = "dummy"; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->structBegin(dummy_str)); + } + { + std::string dummy_str = "dummy"; + ThriftProxy::FieldType dummy_ft{ThriftProxy::FieldType::I32}; + int16_t dummy_id{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, + filter_->fieldBegin(dummy_str, dummy_ft, dummy_id)); + } + { + bool dummy_val{false}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->boolValue(dummy_val)); + } + { + uint8_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->byteValue(dummy_val)); + } + { + int16_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->int16Value(dummy_val)); + } + { + int32_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->int32Value(dummy_val)); + } + { + int64_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->int64Value(dummy_val)); + } + { + double dummy_val{0.0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->doubleValue(dummy_val)); + } + { + std::string dummy_str = "dummy"; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->stringValue(dummy_str)); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, + filter_->mapBegin(dummy_ft, dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->mapEnd()); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->listBegin(dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->listEnd()); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->setBegin(dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->setEnd()); + } + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->structEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->fieldEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->messageEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->transportEnd()); + + EXPECT_NO_THROW(filter_->onDestroy()); +} + +class ThriftPassThroughBidirectionalFilterTest : public testing::Test { +public: + class Filter : public PassThroughBidirectionalFilter { + public: + DecoderFilterCallbacks* decoderFilterCallbacks() { return decoder_callbacks_; } + EncoderFilterCallbacks* encoderFilterCallbacks() { return encoder_callbacks_; } + }; + + void initialize() { + filter_ = std::make_unique(); + filter_->setEncoderFilterCallbacks(encoder_filter_callbacks_); + filter_->setDecoderFilterCallbacks(decoder_filter_callbacks_); + } + + std::unique_ptr filter_; + NiceMock encoder_filter_callbacks_; + NiceMock decoder_filter_callbacks_; + ThriftProxy::MessageMetadataSharedPtr request_metadata_; +}; + +// Tests that each method returns ThriftProxy::FilterStatus::Continue. +TEST_F(ThriftPassThroughBidirectionalFilterTest, AllMethodsAreImplementedTrivially) { + initialize(); + + EXPECT_EQ(&decoder_filter_callbacks_, filter_->decoderFilterCallbacks()); + EXPECT_TRUE(filter_->decodePassthroughSupported()); + EXPECT_TRUE(filter_->encodePassthroughSupported()); + + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeTransportBegin(request_metadata_)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeMessageBegin(request_metadata_)); + { + std::string dummy_str = "dummy"; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeStructBegin(dummy_str)); + } + { + std::string dummy_str = "dummy"; + ThriftProxy::FieldType dummy_ft{ThriftProxy::FieldType::I32}; + int16_t dummy_id{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, + filter_->decodeFieldBegin(dummy_str, dummy_ft, dummy_id)); + } + { + bool dummy_val{false}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeBoolValue(dummy_val)); + } + { + uint8_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeByteValue(dummy_val)); + } + { + int16_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeInt16Value(dummy_val)); + } + { + int32_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeInt32Value(dummy_val)); + } + { + int64_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeInt64Value(dummy_val)); + } + { + double dummy_val{0.0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeDoubleValue(dummy_val)); + } + { + std::string dummy_str = "dummy"; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeStringValue(dummy_str)); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, + filter_->decodeMapBegin(dummy_ft, dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeMapEnd()); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeListBegin(dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeListEnd()); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeSetBegin(dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeSetEnd()); + } + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeStructEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeFieldEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeMessageEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->decodeTransportEnd()); + + // Encoding phase. + EXPECT_EQ(&encoder_filter_callbacks_, filter_->encoderFilterCallbacks()); + + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeTransportBegin(request_metadata_)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeMessageBegin(request_metadata_)); + { + std::string dummy_str = "dummy"; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeStructBegin(dummy_str)); + } + { + std::string dummy_str = "dummy"; + ThriftProxy::FieldType dummy_ft{ThriftProxy::FieldType::I32}; + int16_t dummy_id{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, + filter_->encodeFieldBegin(dummy_str, dummy_ft, dummy_id)); + } + { + bool dummy_val{false}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeBoolValue(dummy_val)); + } + { + uint8_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeByteValue(dummy_val)); + } + { + int16_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeInt16Value(dummy_val)); + } + { + int32_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeInt32Value(dummy_val)); + } + { + int64_t dummy_val{0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeInt64Value(dummy_val)); + } + { + double dummy_val{0.0}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeDoubleValue(dummy_val)); + } + { + std::string dummy_str = "dummy"; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeStringValue(dummy_str)); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, + filter_->encodeMapBegin(dummy_ft, dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeMapEnd()); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeListBegin(dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeListEnd()); + } + { + ThriftProxy::FieldType dummy_ft = ThriftProxy::FieldType::I32; + uint32_t dummy_size{1}; + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeSetBegin(dummy_ft, dummy_size)); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeSetEnd()); + } + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeStructEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeFieldEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeMessageEnd()); + EXPECT_EQ(ThriftProxy::FilterStatus::Continue, filter_->encodeTransportEnd()); + + EXPECT_NO_THROW(filter_->onDestroy()); +} + } // namespace ThriftFilters } // namespace ThriftProxy } // namespace NetworkFilters diff --git a/test/extensions/filters/network/thrift_proxy/mocks.cc b/test/extensions/filters/network/thrift_proxy/mocks.cc index 0153c42b6103f..86cd6b3decc2a 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.cc +++ b/test/extensions/filters/network/thrift_proxy/mocks.cc @@ -92,6 +92,88 @@ MockDecoderFilterCallbacks::MockDecoderFilterCallbacks() { } MockDecoderFilterCallbacks::~MockDecoderFilterCallbacks() = default; +MockEncoderFilter::MockEncoderFilter() { + ON_CALL(*this, transportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, transportEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, messageEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, structEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, fieldEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, boolValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, byteValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int16Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int32Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, int64Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, doubleValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, stringValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, mapEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, listEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, setEnd()).WillByDefault(Return(FilterStatus::Continue)); +} +MockEncoderFilter::~MockEncoderFilter() = default; + +MockEncoderFilterCallbacks::MockEncoderFilterCallbacks() { + route_ = std::make_shared>(); + + ON_CALL(*this, streamId()).WillByDefault(Return(stream_id_)); + ON_CALL(*this, connection()).WillByDefault(Return(&connection_)); + ON_CALL(*this, route()).WillByDefault(Return(route_)); + ON_CALL(*this, streamInfo()).WillByDefault(ReturnRef(stream_info_)); +} +MockEncoderFilterCallbacks::~MockEncoderFilterCallbacks() = default; + +MockBidirectionalFilter::MockBidirectionalFilter() { + ON_CALL(*this, decodeTransportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeTransportEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeMessageBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeMessageEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeStructBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeStructEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeFieldBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeFieldEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeBoolValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeByteValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeInt16Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeInt32Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeInt64Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeDoubleValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeStringValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeMapBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeMapEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeListBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeListEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeSetBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, decodeSetEnd()).WillByDefault(Return(FilterStatus::Continue)); + + ON_CALL(*this, encodeTransportBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeTransportEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeMessageBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeMessageEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeStructBegin(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeStructEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeFieldBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeFieldEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeBoolValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeByteValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeInt16Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeInt32Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeInt64Value(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeDoubleValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeStringValue(_)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeMapBegin(_, _, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeMapEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeListBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeListEnd()).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeSetBegin(_, _)).WillByDefault(Return(FilterStatus::Continue)); + ON_CALL(*this, encodeSetEnd()).WillByDefault(Return(FilterStatus::Continue)); +} +MockBidirectionalFilter::~MockBidirectionalFilter() = default; + MockFilterConfigFactory::MockFilterConfigFactory() : name_("envoy.filters.thrift.mock_filter") { mock_filter_ = std::make_shared>(); } diff --git a/test/extensions/filters/network/thrift_proxy/mocks.h b/test/extensions/filters/network/thrift_proxy/mocks.h index 0fd9f8cc54558..4fb758e28d6fb 100644 --- a/test/extensions/filters/network/thrift_proxy/mocks.h +++ b/test/extensions/filters/network/thrift_proxy/mocks.h @@ -201,6 +201,8 @@ class MockFilterChainFactoryCallbacks : public FilterChainFactoryCallbacks { ~MockFilterChainFactoryCallbacks() override; MOCK_METHOD(void, addDecoderFilter, (DecoderFilterSharedPtr)); + MOCK_METHOD(void, addEncoderFilter, (EncoderFilterSharedPtr)); + MOCK_METHOD(void, addBidirectionalFilter, (BidirectionalFilterSharedPtr)); }; class MockDecoderFilter : public DecoderFilter { @@ -269,6 +271,132 @@ class MockDecoderFilterCallbacks : public DecoderFilterCallbacks { std::shared_ptr route_; }; +class MockEncoderFilter : public EncoderFilter { +public: + MockEncoderFilter(); + ~MockEncoderFilter() override; + + // ThriftProxy::ThriftFilters::EncoderFilter + MOCK_METHOD(void, onDestroy, ()); + MOCK_METHOD(void, setEncoderFilterCallbacks, (EncoderFilterCallbacks & callbacks)); + MOCK_METHOD(bool, passthroughSupported, (), (const)); + + // ThriftProxy::DecoderEventHandler + MOCK_METHOD(FilterStatus, passthroughData, (Buffer::Instance & data)); + MOCK_METHOD(FilterStatus, transportBegin, (MessageMetadataSharedPtr metadata)); + MOCK_METHOD(FilterStatus, transportEnd, ()); + MOCK_METHOD(FilterStatus, messageBegin, (MessageMetadataSharedPtr metadata)); + MOCK_METHOD(FilterStatus, messageEnd, ()); + MOCK_METHOD(FilterStatus, structBegin, (absl::string_view name)); + MOCK_METHOD(FilterStatus, structEnd, ()); + MOCK_METHOD(FilterStatus, fieldBegin, + (absl::string_view name, FieldType& msg_type, int16_t& field_id)); + MOCK_METHOD(FilterStatus, fieldEnd, ()); + MOCK_METHOD(FilterStatus, boolValue, (bool& value)); + MOCK_METHOD(FilterStatus, byteValue, (uint8_t & value)); + MOCK_METHOD(FilterStatus, int16Value, (int16_t & value)); + MOCK_METHOD(FilterStatus, int32Value, (int32_t & value)); + MOCK_METHOD(FilterStatus, int64Value, (int64_t & value)); + MOCK_METHOD(FilterStatus, doubleValue, (double& value)); + MOCK_METHOD(FilterStatus, stringValue, (absl::string_view value)); + MOCK_METHOD(FilterStatus, mapBegin, + (FieldType & key_type, FieldType& value_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, mapEnd, ()); + MOCK_METHOD(FilterStatus, listBegin, (FieldType & elem_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, listEnd, ()); + MOCK_METHOD(FilterStatus, setBegin, (FieldType & elem_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, setEnd, ()); +}; + +class MockEncoderFilterCallbacks : public EncoderFilterCallbacks { +public: + MockEncoderFilterCallbacks(); + ~MockEncoderFilterCallbacks() override; + + // ThriftProxy::ThriftFilters::EncoderFilterCallbacks + MOCK_METHOD(uint64_t, streamId, (), (const)); + MOCK_METHOD(const Network::Connection*, connection, (), (const)); + MOCK_METHOD(Event::Dispatcher&, dispatcher, ()); + MOCK_METHOD(void, continueEncoding, ()); + MOCK_METHOD(Router::RouteConstSharedPtr, route, ()); + MOCK_METHOD(TransportType, downstreamTransportType, (), (const)); + MOCK_METHOD(ProtocolType, downstreamProtocolType, (), (const)); + MOCK_METHOD(void, resetDownstreamConnection, ()); + MOCK_METHOD(StreamInfo::StreamInfo&, streamInfo, ()); + MOCK_METHOD(MessageMetadataSharedPtr, responseMetadata, ()); + MOCK_METHOD(bool, responseSuccess, ()); + MOCK_METHOD(void, onReset, ()); + + uint64_t stream_id_{1}; + NiceMock connection_; + NiceMock stream_info_; + MessageMetadataSharedPtr metadata_; + std::shared_ptr route_; +}; + +class MockBidirectionalFilter : public BidirectionalFilter { +public: + MockBidirectionalFilter(); + ~MockBidirectionalFilter() override; + + // ThriftProxy::ThriftFilters::BidirectionalFilter + MOCK_METHOD(void, onDestroy, ()); + MOCK_METHOD(void, setEncoderFilterCallbacks, (EncoderFilterCallbacks & callbacks)); + MOCK_METHOD(bool, encodePassthroughSupported, (), (const)); + MOCK_METHOD(void, setDecoderFilterCallbacks, (DecoderFilterCallbacks & callbacks)); + MOCK_METHOD(bool, decodePassthroughSupported, (), (const)); + + MOCK_METHOD(FilterStatus, encodePassthroughData, (Buffer::Instance & data)); + MOCK_METHOD(FilterStatus, encodeTransportBegin, (MessageMetadataSharedPtr metadata)); + MOCK_METHOD(FilterStatus, encodeTransportEnd, ()); + MOCK_METHOD(FilterStatus, encodeMessageBegin, (MessageMetadataSharedPtr metadata)); + MOCK_METHOD(FilterStatus, encodeMessageEnd, ()); + MOCK_METHOD(FilterStatus, encodeStructBegin, (absl::string_view name)); + MOCK_METHOD(FilterStatus, encodeStructEnd, ()); + MOCK_METHOD(FilterStatus, encodeFieldBegin, + (absl::string_view name, FieldType& msg_type, int16_t& field_id)); + MOCK_METHOD(FilterStatus, encodeFieldEnd, ()); + MOCK_METHOD(FilterStatus, encodeBoolValue, (bool& value)); + MOCK_METHOD(FilterStatus, encodeByteValue, (uint8_t & value)); + MOCK_METHOD(FilterStatus, encodeInt16Value, (int16_t & value)); + MOCK_METHOD(FilterStatus, encodeInt32Value, (int32_t & value)); + MOCK_METHOD(FilterStatus, encodeInt64Value, (int64_t & value)); + MOCK_METHOD(FilterStatus, encodeDoubleValue, (double& value)); + MOCK_METHOD(FilterStatus, encodeStringValue, (absl::string_view value)); + MOCK_METHOD(FilterStatus, encodeMapBegin, + (FieldType & key_type, FieldType& value_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, encodeMapEnd, ()); + MOCK_METHOD(FilterStatus, encodeListBegin, (FieldType & elem_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, encodeListEnd, ()); + MOCK_METHOD(FilterStatus, encodeSetBegin, (FieldType & elem_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, encodeSetEnd, ()); + + MOCK_METHOD(FilterStatus, decodePassthroughData, (Buffer::Instance & data)); + MOCK_METHOD(FilterStatus, decodeTransportBegin, (MessageMetadataSharedPtr metadata)); + MOCK_METHOD(FilterStatus, decodeTransportEnd, ()); + MOCK_METHOD(FilterStatus, decodeMessageBegin, (MessageMetadataSharedPtr metadata)); + MOCK_METHOD(FilterStatus, decodeMessageEnd, ()); + MOCK_METHOD(FilterStatus, decodeStructBegin, (absl::string_view name)); + MOCK_METHOD(FilterStatus, decodeStructEnd, ()); + MOCK_METHOD(FilterStatus, decodeFieldBegin, + (absl::string_view name, FieldType& msg_type, int16_t& field_id)); + MOCK_METHOD(FilterStatus, decodeFieldEnd, ()); + MOCK_METHOD(FilterStatus, decodeBoolValue, (bool& value)); + MOCK_METHOD(FilterStatus, decodeByteValue, (uint8_t & value)); + MOCK_METHOD(FilterStatus, decodeInt16Value, (int16_t & value)); + MOCK_METHOD(FilterStatus, decodeInt32Value, (int32_t & value)); + MOCK_METHOD(FilterStatus, decodeInt64Value, (int64_t & value)); + MOCK_METHOD(FilterStatus, decodeDoubleValue, (double& value)); + MOCK_METHOD(FilterStatus, decodeStringValue, (absl::string_view value)); + MOCK_METHOD(FilterStatus, decodeMapBegin, + (FieldType & key_type, FieldType& value_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, decodeMapEnd, ()); + MOCK_METHOD(FilterStatus, decodeListBegin, (FieldType & elem_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, decodeListEnd, ()); + MOCK_METHOD(FilterStatus, decodeSetBegin, (FieldType & elem_type, uint32_t& size)); + MOCK_METHOD(FilterStatus, decodeSetEnd, ()); +}; + class MockFilterConfigFactory : public NamedThriftFilterConfigFactory { public: MockFilterConfigFactory(); diff --git a/tools/spelling/spelling_dictionary.txt b/tools/spelling/spelling_dictionary.txt index 6422f6a975d06..4ee3394bd3abb 100644 --- a/tools/spelling/spelling_dictionary.txt +++ b/tools/spelling/spelling_dictionary.txt @@ -23,6 +23,7 @@ AWS BACKTRACE BEL BBR +BIDIRECTIONAL BSON BPF Repick