diff --git a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto index 1a7176dc33031..97c905eafc608 100644 --- a/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto +++ b/api/envoy/config/filter/network/thrift_proxy/v2alpha1/thrift_proxy.proto @@ -23,6 +23,9 @@ message ThriftProxy { // The Thrift proxy will assume the client is using the Thrift unframed transport. UNFRAMED = 2; + + // The Thrift proxy will assume the client is using the Thrift header transport. + HEADER = 3; } // Supplies the type of transport that the Thrift proxy should use. Defaults to `AUTO_TRANSPORT`. diff --git a/source/extensions/filters/network/thrift_proxy/BUILD b/source/extensions/filters/network/thrift_proxy/BUILD index e75efa1c07609..011949996723b 100644 --- a/source/extensions/filters/network/thrift_proxy/BUILD +++ b/source/extensions/filters/network/thrift_proxy/BUILD @@ -187,11 +187,13 @@ envoy_cc_library( name = "transport_lib", srcs = [ "framed_transport_impl.cc", + "header_transport_impl.cc", "transport_impl.cc", "unframed_transport_impl.cc", ], hdrs = [ "framed_transport_impl.h", + "header_transport_impl.h", "transport_impl.h", "unframed_transport_impl.h", ], diff --git a/source/extensions/filters/network/thrift_proxy/config.cc b/source/extensions/filters/network/thrift_proxy/config.cc index 22ce4fbf3e163..54134537a056f 100644 --- a/source/extensions/filters/network/thrift_proxy/config.cc +++ b/source/extensions/filters/network/thrift_proxy/config.cc @@ -41,6 +41,8 @@ static const TransportTypeMap& transportTypeMap() { {envoy::config::filter::network::thrift_proxy::v2alpha1:: ThriftProxy_TransportType_UNFRAMED, TransportType::Unframed}, + {envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType_HEADER, + TransportType::Header}, }); } @@ -92,6 +94,20 @@ ConfigImpl::ConfigImpl( transport_(config.transport()), proto_(config.protocol()), route_matcher_(new Router::RouteMatcher(config.route_config())) { + if (transportTypeMap().find(transport_) == transportTypeMap().end()) { + throw EnvoyException(fmt::format( + "unknown transport {}", + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType_Name( + transport_))); + } + + if (protocolTypeMap().find(proto_) == protocolTypeMap().end()) { + throw EnvoyException(fmt::format( + "unknown protocol {}", + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType_Name( + proto_))); + } + // Construct the only Thrift DecoderFilter: the Router auto& factory = Envoy::Config::Utility::getAndCheckFactory( @@ -115,14 +131,14 @@ DecoderPtr ConfigImpl::createDecoder(DecoderCallbacks& callbacks) { TransportPtr ConfigImpl::createTransport() { TransportTypeMap::const_iterator i = transportTypeMap().find(transport_); - RELEASE_ASSERT(i != transportTypeMap().end(), "invalid transport type"); + ASSERT(i != transportTypeMap().end()); return NamedTransportConfigFactory::getFactory(i->second).createTransport(); } ProtocolPtr ConfigImpl::createProtocol() { ProtocolTypeMap::const_iterator i = protocolTypeMap().find(proto_); - RELEASE_ASSERT(i != protocolTypeMap().end(), "invalid protocol type"); + ASSERT(i != protocolTypeMap().end()); return NamedProtocolConfigFactory::getFactory(i->second).createProtocol(); } diff --git a/source/extensions/filters/network/thrift_proxy/decoder.cc b/source/extensions/filters/network/thrift_proxy/decoder.cc index d3d69b8e09792..f2219d68acb50 100644 --- a/source/extensions/filters/network/thrift_proxy/decoder.cc +++ b/source/extensions/filters/network/thrift_proxy/decoder.cc @@ -385,6 +385,25 @@ ThriftFilters::FilterStatus Decoder::onData(Buffer::Instance& data, bool& buffer } ENVOY_LOG(debug, "thrift: {} transport started", transport_->name()); + if (metadata_->hasProtocol()) { + if (protocol_->type() == ProtocolType::Auto) { + protocol_->setType(metadata_->protocol()); + ENVOY_LOG(debug, "thrift: {} transport forced {} protocol", transport_->name(), + protocol_->name()); + } else if (metadata_->protocol() != protocol_->type()) { + throw EnvoyException(fmt::format("transport reports protocol {}, but configured for {}", + ProtocolNames::get().fromType(metadata_->protocol()), + ProtocolNames::get().fromType(protocol_->type()))); + } + } + if (metadata_->hasAppException()) { + AppExceptionType ex_type = metadata_->appExceptionType(); + std::string ex_msg = metadata_->appExceptionMessage(); + // Force new metadata if we get called again. + metadata_.reset(); + throw AppException(ex_type, ex_msg); + } + request_ = std::make_unique(callbacks_.newDecoderFilter()); frame_started_ = true; state_machine_ = diff --git a/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc b/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc new file mode 100644 index 0000000000000..0dee09aed8c0b --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/header_transport_impl.cc @@ -0,0 +1,316 @@ +#include "extensions/filters/network/thrift_proxy/header_transport_impl.h" + +#include + +#include "envoy/common/exception.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/buffer_helper.h" +#include "extensions/filters/network/thrift_proxy/transport_impl.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { +namespace { + +// c.f. +// https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/protocol/TProtocolTypes.h#L27 +enum class HeaderProtocolType { + Binary = 0, + JSON = 1, + Compact = 2, + + FirstHeaderProtocolType = Binary, + LastHeaderProtocolType = Compact, +}; + +// Fixed portion of frame header: +// Header magic: 2 bytes + +// Flags: 2 bytes + +// Sequence number: 4 bytes +// Header data size: 2 bytes +constexpr uint64_t MinFrameStartSizeNoHeaders = 10; + +// Minimum frame size: fixed portion of frame header + 4 bytes of header data (the minimum) +constexpr int32_t MinFrameStartSize = MinFrameStartSizeNoHeaders + 4; + +// Minimum to start decoding: 4 bytes of frame size + the fixed portion of the frame header +constexpr uint64_t MinDecodeBytes = MinFrameStartSizeNoHeaders + 4; + +// Maximum size for header data. +constexpr int32_t MaxHeadersSize = 65536; + +} // namespace + +bool HeaderTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) { + if (buffer.length() < MinDecodeBytes) { + return false; + } + + // Size of frame, not including the length bytes. + const int32_t frame_size = BufferHelper::peekI32(buffer); + + // Minimum header frame size is 18 bytes (4 bytes of frame size + 10 bytes of fixed header + + // minimum 4 bytes of variable header data), so frame_size must be at least 14. + if (frame_size < MinFrameStartSize || frame_size > MaxFrameSize) { + throw EnvoyException(fmt::format("invalid thrift header transport frame size {}", frame_size)); + } + + int16_t magic = BufferHelper::peekU16(buffer, 4); + if (!isMagic(magic)) { + throw EnvoyException(fmt::format("invalid thrift header transport magic {:04x}", magic)); + } + + // offset 6: 16 bit flags field, unused + // offset 8: 32 bit sequence number field + int32_t seq_id = BufferHelper::peekI32(buffer, 8); + + // offset 12: 16 bit (remaining) header size / 4 (spec erroneously claims / 32). + int16_t raw_header_size = BufferHelper::peekI16(buffer, 12); + int32_t header_size = static_cast(raw_header_size) * 4; + if (header_size < 0 || header_size > MaxHeadersSize) { + throw EnvoyException(fmt::format("invalid thrift header transport header size {} ({:04x})", + header_size, static_cast(raw_header_size))); + } + + if (header_size == 0) { + throw EnvoyException("no header data"); + } + + if (buffer.length() < static_cast(header_size) + MinDecodeBytes) { + // Need more header data. + return false; + } + + // Header data starts at offset 14 (4 bytes of frame size followed by 10 bytes of fixed header). + buffer.drain(MinDecodeBytes); + + // Remaining frame size is the original frame size (which does not count itself), less the 10 + // fixed bytes of the header (magic, flags, etc), less the size of the variable header data + // (header_size). + metadata.setFrameSize( + static_cast(frame_size - header_size - MinFrameStartSizeNoHeaders)); + metadata.setSequenceId(seq_id); + + ProtocolType proto = ProtocolType::Auto; + HeaderProtocolType header_proto = + static_cast(drainVarIntI16(buffer, header_size, "protocol id")); + switch (header_proto) { + case HeaderProtocolType::Binary: + proto = ProtocolType::Binary; + break; + case HeaderProtocolType::Compact: + proto = ProtocolType::Compact; + break; + default: + throw EnvoyException(fmt::format("Unknown protocol {}", static_cast(header_proto))); + } + metadata.setProtocol(proto); + + int16_t num_xforms = drainVarIntI16(buffer, header_size, "transform count"); + if (num_xforms < 0) { + throw EnvoyException(fmt::format("invalid header transport transform count {}", num_xforms)); + } + + while (num_xforms-- > 0) { + int32_t xform_id = drainVarIntI32(buffer, header_size, "transform id"); + + // To date, no transforms have a data field. In the future, some transform IDs may require + // consuming another varint 32 at this point. The known transform IDs are: + // 1: zlib compression + // 2: hmac (appended to end of packet) + // 3: snappy compression + buffer.drain(header_size); + metadata.setAppException(AppExceptionType::MissingResult, + fmt::format("Unknown transform {}", xform_id)); + return true; + } + + while (header_size > 0) { + // Attempt to read info blocks + int32_t info_id = drainVarIntI32(buffer, header_size, "info id"); + if (info_id != 1) { + // 0 indicates a padding byte, and the end of the info block. + // 1 indicates an info id header/value pair. + // Any other value is an unknown info id block, which we ignore. + break; + } + + int32_t num_headers = drainVarIntI32(buffer, header_size, "header count"); + if (num_headers < 0) { + throw EnvoyException(fmt::format("invalid header transport header count {}", num_headers)); + } + + while (num_headers-- > 0) { + std::string key = drainVarString(buffer, header_size, "header key"); + std::string value = drainVarString(buffer, header_size, "header value"); + metadata.addHeader(Header(key, value)); + } + } + + // Remaining bytes are padding or ignored info blocks. + if (header_size > 0) { + buffer.drain(header_size); + } + + return true; +} + +bool HeaderTransportImpl::decodeFrameEnd(Buffer::Instance&) { + exception_.reset(); + exception_reason_.clear(); + + return true; +} + +void HeaderTransportImpl::encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) { + uint64_t msg_size = message.length(); + if (msg_size == 0) { + throw EnvoyException(fmt::format("invalid thrift header transport message size {}", msg_size)); + } + + const HeaderMap& headers = metadata.headers(); + if (headers.size() > MaxHeadersSize / 2) { + // Each header takes a minimum of 2 bytes, yielding this limit. + throw EnvoyException( + fmt::format("invalid thrift header transport too many headers {}", headers.size())); + } + + Buffer::OwnedImpl header_buffer; + + if (!metadata.hasProtocol()) { + throw EnvoyException("missing header transport protocol"); + } + + switch (metadata.protocol()) { + case ProtocolType::Binary: + BufferHelper::writeVarIntI32(header_buffer, static_cast(HeaderProtocolType::Binary)); + break; + case ProtocolType::Compact: + BufferHelper::writeVarIntI32(header_buffer, static_cast(HeaderProtocolType::Compact)); + break; + default: + throw EnvoyException(fmt::format("invalid header transport protocol {}", + ProtocolNames::get().fromType(metadata.protocol()))); + } + + BufferHelper::writeVarIntI32(header_buffer, 0); // num transforms + if (headers.size() > 0) { + // Info ID 1 + BufferHelper::writeI8(header_buffer, 1); + + // Num headers + BufferHelper::writeVarIntI32(header_buffer, static_cast(headers.size())); + + for (const Header& header : headers) { + writeVarString(header_buffer, header.key()); + writeVarString(header_buffer, header.value()); + } + } + + uint64_t header_size = header_buffer.length(); + + // Always pad (as the Apache implementation does). + int padding = 4 - (header_size % 4); + header_buffer.add("\0\0\0\0", padding); + header_size += padding; + + if (header_size > MaxHeadersSize) { + throw EnvoyException( + fmt::format("invalid thrift header transport header size {}", header_size)); + } + + // Frame size does not include the frame length itself. + uint64_t size = header_size + msg_size + MinFrameStartSizeNoHeaders; + if (size > MaxFrameSize) { + throw EnvoyException(fmt::format("invalid thrift header transport frame size {}", size)); + } + + int32_t seq_id = 0; + if (metadata.hasSequenceId()) { + seq_id = metadata.sequenceId(); + } + + BufferHelper::writeU32(buffer, static_cast(size)); + BufferHelper::writeU16(buffer, Magic); + BufferHelper::writeU16(buffer, 0); // flags + BufferHelper::writeI32(buffer, seq_id); + BufferHelper::writeU16(buffer, static_cast(header_size / 4)); + buffer.move(header_buffer); + buffer.move(message); +} + +int16_t HeaderTransportImpl::drainVarIntI16(Buffer::Instance& buffer, int32_t& header_size, + const char* desc) { + int32_t value = drainVarIntI32(buffer, header_size, desc); + if (value > static_cast(std::numeric_limits::max())) { + throw EnvoyException(fmt::format("header transport {}: value {} exceeds max i16 ({})", desc, + value, std::numeric_limits::max())); + } + return static_cast(value); +} + +int32_t HeaderTransportImpl::drainVarIntI32(Buffer::Instance& buffer, int32_t& header_size, + const char* desc) { + if (header_size <= 0) { + throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc)); + } + + int size; + int32_t value = BufferHelper::peekVarIntI32(buffer, 0, size); + if (size < 0 || (header_size - size) < 0) { + throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc)); + } + buffer.drain(size); + header_size -= size; + return value; +} + +std::string HeaderTransportImpl::drainVarString(Buffer::Instance& buffer, int32_t& header_size, + const char* desc) { + int16_t str_len = drainVarIntI16(buffer, header_size, desc); + if (str_len == 0) { + return ""; + } + + if (header_size < static_cast(str_len)) { + throw EnvoyException(fmt::format("unable to read header transport {}: header too small", desc)); + } + + std::string value(static_cast(buffer.linearize(str_len)), str_len); + buffer.drain(str_len); + header_size -= str_len; + return value; +} + +void HeaderTransportImpl::writeVarString(Buffer::Instance& buffer, const std::string& str) { + std::string::size_type len = str.length(); + if (len > static_cast(std::numeric_limits::max())) { + throw EnvoyException(fmt::format("header string too long: {}", len)); + } + + BufferHelper::writeVarIntI32(buffer, static_cast(len)); + if (len == 0) { + return; + } + buffer.add(str); +} + +class HeaderTransportConfigFactory : public TransportFactoryBase { +public: + HeaderTransportConfigFactory() : TransportFactoryBase(TransportNames::get().HEADER) {} +}; + +/** + * Static registration for the header transport. @see RegisterFactory. + */ +static Registry::RegisterFactory + register_; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/header_transport_impl.h b/source/extensions/filters/network/thrift_proxy/header_transport_impl.h new file mode 100644 index 0000000000000..88807904018f9 --- /dev/null +++ b/source/extensions/filters/network/thrift_proxy/header_transport_impl.h @@ -0,0 +1,64 @@ +#pragma once + +#include + +#include "envoy/buffer/buffer.h" + +#include "extensions/filters/network/thrift_proxy/app_exception_impl.h" +#include "extensions/filters/network/thrift_proxy/metadata.h" +#include "extensions/filters/network/thrift_proxy/protocol.h" +#include "extensions/filters/network/thrift_proxy/transport_impl.h" + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +/** + * HeaderTransportImpl implements the Thrift Header transport. + * See https://github.com/apache/thrift/blob/master/doc/specs/HeaderFormat.md and + * https://github.com/apache/thrift/blob/master/lib/cpp/src/thrift/transport/THeaderTransport.h + * (for constants not specified in the spec). + */ +class HeaderTransportImpl : public Transport { +public: + // Transport + const std::string& name() const override { return TransportNames::get().HEADER; } + TransportType type() const override { return TransportType::Header; } + bool decodeFrameStart(Buffer::Instance& buffer, MessageMetadata& metadata) override; + bool decodeFrameEnd(Buffer::Instance& buffer) override; + void encodeFrame(Buffer::Instance& buffer, const MessageMetadata& metadata, + Buffer::Instance& message) override; + + static bool isMagic(uint16_t word) { return word == Magic; } + + static constexpr int32_t MaxFrameSize = 0x3FFFFFFF; + +private: + static constexpr uint16_t Magic = 0x0FFF; + + static int16_t drainVarIntI16(Buffer::Instance& buffer, int32_t& header_size, const char* desc); + static int32_t drainVarIntI32(Buffer::Instance& buffer, int32_t& header_size, const char* desc); + static std::string drainVarString(Buffer::Instance& buffer, int32_t& header_size, + const char* desc); + static void writeVarString(Buffer::Instance& buffer, const std::string& str); + + void setException(AppExceptionType type, std::string reason) { + if (exception_.has_value()) { + return; + } + + exception_ = type; + exception_reason_ = reason; + } + + absl::optional exception_; + std::string exception_reason_; +}; + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/thrift_proxy/protocol.h b/source/extensions/filters/network/thrift_proxy/protocol.h index 7d5baa3b9d9f6..0e4066021ce9b 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol.h +++ b/source/extensions/filters/network/thrift_proxy/protocol.h @@ -39,6 +39,13 @@ class Protocol { */ virtual ProtocolType type() const PURE; + /** + * For protocol-detecting implementations, set the underlying type based on external + * (e.g. transport-level) information). + * @param type ProtocolType to explicitly set + */ + virtual void setType(ProtocolType) { NOT_IMPLEMENTED_GCOVR_EXCL_LINE; } + /** * Reads the start of a Thrift protocol message from the buffer and updates the metadata * parameter with values from the message header. If successful, the message header is removed diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc index d0780b20d9591..4d1cb437252aa 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.cc @@ -17,6 +17,22 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { +void AutoProtocolImpl::setType(ProtocolType type) { + if (!protocol_) { + switch (type) { + case ProtocolType::Binary: + setProtocol(std::make_unique()); + break; + case ProtocolType::Compact: + setProtocol(std::make_unique()); + break; + default: + // Ignored: attempt protocol detection. + break; + } + } +} + bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) { if (protocol_ == nullptr) { if (buffer.length() < 2) { @@ -25,9 +41,9 @@ bool AutoProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadat uint16_t version = BufferHelper::peekU16(buffer); if (BinaryProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique()); + setType(ProtocolType::Binary); } else if (CompactProtocolImpl::isMagic(version)) { - setProtocol(std::make_unique()); + setType(ProtocolType::Compact); } if (!protocol_) { diff --git a/source/extensions/filters/network/thrift_proxy/protocol_impl.h b/source/extensions/filters/network/thrift_proxy/protocol_impl.h index 227941ca5811a..1f720f458cce9 100644 --- a/source/extensions/filters/network/thrift_proxy/protocol_impl.h +++ b/source/extensions/filters/network/thrift_proxy/protocol_impl.h @@ -30,6 +30,7 @@ class AutoProtocolImpl : public Protocol { } return ProtocolType::Auto; } + void setType(ProtocolType type) override; bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override; bool readMessageEnd(Buffer::Instance& buffer) override; diff --git a/source/extensions/filters/network/thrift_proxy/thrift.h b/source/extensions/filters/network/thrift_proxy/thrift.h index fe3d7022a59e4..3fa29944fd7b3 100644 --- a/source/extensions/filters/network/thrift_proxy/thrift.h +++ b/source/extensions/filters/network/thrift_proxy/thrift.h @@ -10,6 +10,7 @@ namespace ThriftProxy { enum class TransportType { Framed, + Header, Unframed, Auto, @@ -26,6 +27,9 @@ class TransportNameValues { // Framed transport const std::string FRAMED = "framed"; + // Header transport + const std::string HEADER = "header"; + // Unframed transport const std::string UNFRAMED = "unframed"; @@ -36,6 +40,8 @@ class TransportNameValues { switch (type) { case TransportType::Framed: return FRAMED; + case TransportType::Header: + return HEADER; case TransportType::Unframed: return UNFRAMED; case TransportType::Auto: diff --git a/source/extensions/filters/network/thrift_proxy/transport_impl.cc b/source/extensions/filters/network/thrift_proxy/transport_impl.cc index 0d5971f03e7e3..d7efc60242397 100644 --- a/source/extensions/filters/network/thrift_proxy/transport_impl.cc +++ b/source/extensions/filters/network/thrift_proxy/transport_impl.cc @@ -8,6 +8,7 @@ #include "extensions/filters/network/thrift_proxy/buffer_helper.h" #include "extensions/filters/network/thrift_proxy/compact_protocol_impl.h" #include "extensions/filters/network/thrift_proxy/framed_transport_impl.h" +#include "extensions/filters/network/thrift_proxy/header_transport_impl.h" #include "extensions/filters/network/thrift_proxy/unframed_transport_impl.h" namespace Envoy { @@ -25,7 +26,16 @@ bool AutoTransportImpl::decodeFrameStart(Buffer::Instance& buffer, MessageMetada int32_t size = BufferHelper::peekI32(buffer); uint16_t proto_start = BufferHelper::peekU16(buffer, 4); - if (size > 0 && size <= FramedTransportImpl::MaxFrameSize) { + // Currently, transport detection depends on the following: + // 1. Protocol may only be binary or compact, which start with 0x8001 or 0x8201. + // 2. If unframed transport, size will appear negative due to leading protocol bytes. + // 3. If header transport, size is followed by 0x0FFF which is distinct from leading + // protocol bytes. + // 4. For framed transport, size is followed by protocol bytes. + if (size > 0 && size <= HeaderTransportImpl::MaxFrameSize && + HeaderTransportImpl::isMagic(proto_start)) { + setTransport(std::make_unique()); + } else if (size > 0 && size <= FramedTransportImpl::MaxFrameSize) { // TODO(zuercher): Spec says max size is 16,384,000 (0xFA0000). Apache C++ TFramedTransport // is configurable, but defaults to 256 MB (0x1000000). if (BinaryProtocolImpl::isMagic(proto_start) || CompactProtocolImpl::isMagic(proto_start)) { diff --git a/test/extensions/filters/network/thrift_proxy/BUILD b/test/extensions/filters/network/thrift_proxy/BUILD index e07767ba0963f..ead7baef2eada 100644 --- a/test/extensions/filters/network/thrift_proxy/BUILD +++ b/test/extensions/filters/network/thrift_proxy/BUILD @@ -147,6 +147,20 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "header_transport_impl_test", + srcs = ["header_transport_impl_test.cc"], + extension_name = "envoy.filters.network.thrift_proxy", + deps = [ + ":mocks", + ":utility_lib", + "//source/extensions/filters/network/thrift_proxy:transport_lib", + "//test/mocks/buffer:buffer_mocks", + "//test/test_common:printers_lib", + "//test/test_common:utility_lib", + ], +) + envoy_extension_cc_test( name = "metadata_test", srcs = ["metadata_test.cc"], diff --git a/test/extensions/filters/network/thrift_proxy/config_test.cc b/test/extensions/filters/network/thrift_proxy/config_test.cc index 220ce5c3fddf6..a1433cd6d5547 100644 --- a/test/extensions/filters/network/thrift_proxy/config_test.cc +++ b/test/extensions/filters/network/thrift_proxy/config_test.cc @@ -14,38 +14,107 @@ namespace Extensions { namespace NetworkFilters { namespace ThriftProxy { -TEST(ThriftFilterConfigTest, ValidateFail) { - NiceMock context; - EXPECT_THROW(ThriftProxyFilterConfigFactory().createFilterFactoryFromProto( - envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy(), context), +namespace { + +std::vector +getTransportTypes() { + std::vector v; + int transport = envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_TransportType_MIN; + while (transport <= envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_TransportType_TransportType_MAX) { + v.push_back(static_cast< + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType>( + transport)); + transport++; + } + return v; +} + +std::vector +getProtocolTypes() { + std::vector v; + int protocol = envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_ProtocolType_ProtocolType_MIN; + while (protocol <= envoy::config::filter::network::thrift_proxy::v2alpha1:: + ThriftProxy_ProtocolType_ProtocolType_MAX) { + v.push_back(static_cast< + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType>( + protocol)); + protocol++; + } + return v; +} + +} // namespace + +class ThriftFilterConfigTestBase { +public: + void testConfig(envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy& config) { + Network::FilterFactoryCb cb; + EXPECT_NO_THROW({ cb = factory_.createFilterFactoryFromProto(config, context_); }); + + Network::MockConnection connection; + EXPECT_CALL(connection, addReadFilter(_)); + cb(connection); + } + + NiceMock context_; + ThriftProxyFilterConfigFactory factory_; +}; + +class ThriftFilterConfigTest : public ThriftFilterConfigTestBase, public testing::Test {}; + +class ThriftFilterTransportConfigTest + : public ThriftFilterConfigTestBase, + public testing::TestWithParam< + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_TransportType> {}; + +INSTANTIATE_TEST_CASE_P(TransportTypes, ThriftFilterTransportConfigTest, + testing::ValuesIn(getTransportTypes())); + +class ThriftFilterProtocolConfigTest + : public ThriftFilterConfigTestBase, + public testing::TestWithParam< + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy_ProtocolType> {}; + +INSTANTIATE_TEST_CASE_P(ProtocolTypes, ThriftFilterProtocolConfigTest, + testing::ValuesIn(getProtocolTypes())); + +TEST_F(ThriftFilterConfigTest, ValidateFail) { + EXPECT_THROW(factory_.createFilterFactoryFromProto( + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy(), context_), ProtoValidationException); } -TEST(ThriftFilterConfigTest, ValidProtoConfiguration) { +TEST_F(ThriftFilterConfigTest, ValidProtoConfiguration) { envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy config{}; + config.set_stat_prefix("my_stat_prefix"); + + testConfig(config); +} +TEST_P(ThriftFilterTransportConfigTest, ValidProtoConfiguration) { + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy config{}; config.set_stat_prefix("my_stat_prefix"); + config.set_transport(GetParam()); + testConfig(config); +} - NiceMock context; - ThriftProxyFilterConfigFactory factory; - Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); - Network::MockConnection connection; - EXPECT_CALL(connection, addReadFilter(_)); - cb(connection); +TEST_P(ThriftFilterProtocolConfigTest, ValidProtoConfiguration) { + envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy config{}; + config.set_stat_prefix("my_stat_prefix"); + config.set_protocol(GetParam()); + testConfig(config); } -TEST(ThriftFilterConfigTest, ThriftProxyWithEmptyProto) { - NiceMock context; - ThriftProxyFilterConfigFactory factory; +TEST_F(ThriftFilterConfigTest, ThriftProxyWithEmptyProto) { envoy::config::filter::network::thrift_proxy::v2alpha1::ThriftProxy config = *dynamic_cast( - factory.createEmptyConfigProto().get()); + factory_.createEmptyConfigProto().get()); config.set_stat_prefix("my_stat_prefix"); - Network::FilterFactoryCb cb = factory.createFilterFactoryFromProto(config, context); - Network::MockConnection connection; - EXPECT_CALL(connection, addReadFilter(_)); - cb(connection); + testConfig(config); } } // namespace ThriftProxy diff --git a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc index 790931376f3d0..fe55b350706e0 100644 --- a/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc +++ b/test/extensions/filters/network/thrift_proxy/conn_manager_test.cc @@ -426,6 +426,48 @@ TEST_F(ThriftConnectionManagerTest, OnDataHandlesProtocolErrorDuringMessageBegin EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); } +TEST_F(ThriftConnectionManagerTest, OnDataHandlesTransportApplicationException) { + initializeFilter(); + addSeq(buffer_, { + 0x00, 0x00, 0x00, 0x64, // header: 100 bytes + 0x0f, 0xff, 0x00, 0x00, // magic, flags + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x00, 0x01, 0x00, 0x02, // header size 4, binary proto, 2 transforms + 0x01, 0x02, 0x00, 0x00, // transforms: 1, 2; padding + }); + + std::string err = "Unknown transform 1"; + uint8_t len = 41 + err.length(); + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, len, // header frame size + 0x0f, 0xff, 0x00, 0x00, // magic, flags + 0x00, 0x00, 0x00, 0x00, // sequence id 0 + 0x00, 0x01, 0x00, 0x00, // header size 4, binary, 0 transforms + 0x00, 0x00, // header padding + 0x80, 0x01, 0x00, 0x03, // binary, exception + 0x00, 0x00, 0x00, 0x00, // message name "" + 0x00, 0x00, 0x00, 0x00, // sequence id + 0x0b, 0x00, 0x01, // begin string field + }); + addInt32(write_buffer_, err.length()); + addString(write_buffer_, err); + addSeq(write_buffer_, { + 0x08, 0x00, 0x02, // begin i32 field + 0x00, 0x00, 0x00, 0x05, // missing result + 0x00, // stop field + }); + + EXPECT_CALL(filter_callbacks_.connection_, write(_, false)) + .WillOnce(Invoke([&](Buffer::Instance& buffer, bool) -> void { + EXPECT_EQ(bufferToString(write_buffer_), bufferToString(buffer)); + })); + EXPECT_CALL(filter_callbacks_.connection_, close(Network::ConnectionCloseType::FlushWrite)); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_decoding_error").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); +} + TEST_F(ThriftConnectionManagerTest, OnEvent) { // No active calls { @@ -720,6 +762,46 @@ TEST_F(ThriftConnectionManagerTest, RequestAndResponseProtocolError) { EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); } +TEST_F(ThriftConnectionManagerTest, RequestAndTransportApplicationException) { + initializeFilter(); + writeMessage(buffer_, TransportType::Header, ProtocolType::Binary, MessageType::Call, 0x0F); + + ThriftFilters::DecoderFilterCallbacks* callbacks{}; + EXPECT_CALL(*decoder_filter_, setDecoderFilterCallbacks(_)) + .WillOnce( + Invoke([&](ThriftFilters::DecoderFilterCallbacks& cb) -> void { callbacks = &cb; })); + + EXPECT_EQ(filter_->onData(buffer_, false), Network::FilterStatus::StopIteration); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + + // Response with unknown transform + addSeq(write_buffer_, { + 0x00, 0x00, 0x00, 0x64, // header: 100 bytes + 0x0f, 0xff, 0x00, 0x00, // magic, flags + 0x00, 0x00, 0x00, 0x01, // sequence id + 0x00, 0x01, 0x00, 0x02, // header size 4, binary proto, 2 transforms + 0x01, 0x02, 0x00, 0x00, // transforms: 1, 2; padding + }); + + callbacks->startUpstreamResponse(TransportType::Header, ProtocolType::Binary); + + EXPECT_CALL(filter_callbacks_.connection_.dispatcher_, deferredDelete_(_)).Times(1); + EXPECT_EQ(true, callbacks->upstreamData(write_buffer_)); + + filter_callbacks_.connection_.dispatcher_.clearDeferredDeleteList(); + + EXPECT_EQ(1U, store_.counter("test.request").value()); + EXPECT_EQ(1U, store_.counter("test.request_call").value()); + EXPECT_EQ(0U, store_.gauge("test.request_active").value()); + EXPECT_EQ(0U, store_.counter("test.response").value()); + EXPECT_EQ(0U, store_.counter("test.response_reply").value()); + EXPECT_EQ(0U, store_.counter("test.response_exception").value()); + EXPECT_EQ(0U, store_.counter("test.response_invalid_type").value()); + EXPECT_EQ(0U, store_.counter("test.response_success").value()); + EXPECT_EQ(0U, store_.counter("test.response_error").value()); + EXPECT_EQ(1U, store_.counter("test.response_decoding_error").value()); +} + TEST_F(ThriftConnectionManagerTest, PipelinedRequestAndResponse) { initializeFilter(); writeFramedBinaryMessage(buffer_, MessageType::Call, 0x01); diff --git a/test/extensions/filters/network/thrift_proxy/decoder_test.cc b/test/extensions/filters/network/thrift_proxy/decoder_test.cc index f18b0a7f6e435..2f13f73122572 100644 --- a/test/extensions/filters/network/thrift_proxy/decoder_test.cc +++ b/test/extensions/filters/network/thrift_proxy/decoder_test.cc @@ -781,6 +781,119 @@ TEST(DecoderTest, OnData) { EXPECT_TRUE(underflow); } +TEST(DecoderTest, OnDataWithProtocolHint) { + NiceMock* transport = new NiceMock(); + NiceMock* proto = new NiceMock(); + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + InSequence dummy; + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + metadata.setProtocol(ProtocolType::Binary); + return true; + })); + EXPECT_CALL(*proto, type()).WillOnce(Return(ProtocolType::Auto)); + EXPECT_CALL(*proto, setType(ProtocolType::Binary)); + EXPECT_CALL(filter, transportBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasFrameSize()); + EXPECT_EQ(100U, metadata->frameSize()); + + EXPECT_TRUE(metadata->hasProtocol()); + EXPECT_EQ(ProtocolType::Binary, metadata->protocol()); + + return ThriftFilters::FilterStatus::Continue; + })); + + EXPECT_CALL(*proto, readMessageBegin(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setMethodName("name"); + metadata.setMessageType(MessageType::Call); + metadata.setSequenceId(100); + return true; + })); + EXPECT_CALL(filter, messageBegin(_)) + .WillOnce(Invoke([&](MessageMetadataSharedPtr metadata) -> ThriftFilters::FilterStatus { + EXPECT_TRUE(metadata->hasMethodName()); + EXPECT_TRUE(metadata->hasMessageType()); + EXPECT_TRUE(metadata->hasSequenceId()); + EXPECT_EQ("name", metadata->methodName()); + EXPECT_EQ(MessageType::Call, metadata->messageType()); + EXPECT_EQ(100U, metadata->sequenceId()); + return ThriftFilters::FilterStatus::Continue; + })); + + EXPECT_CALL(*proto, readStructBegin(Ref(buffer), _)).WillOnce(Return(true)); + EXPECT_CALL(filter, structBegin(absl::string_view())) + .WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + EXPECT_CALL(*proto, readFieldBegin(Ref(buffer), _, _, _)) + .WillOnce(DoAll(SetArgReferee<2>(FieldType::Stop), Return(true))); + EXPECT_CALL(*proto, readStructEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, structEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + EXPECT_CALL(*proto, readMessageEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, messageEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + EXPECT_CALL(*transport, decodeFrameEnd(Ref(buffer))).WillOnce(Return(true)); + EXPECT_CALL(filter, transportEnd()).WillOnce(Return(ThriftFilters::FilterStatus::Continue)); + + bool underflow = false; + EXPECT_EQ(ThriftFilters::FilterStatus::Continue, decoder.onData(buffer, underflow)); + EXPECT_TRUE(underflow); +} + +TEST(DecoderTest, OnDataWithInconsistentProtocolHint) { + NiceMock* transport = new NiceMock(); + NiceMock* proto = new NiceMock(); + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + InSequence dummy; + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setFrameSize(100); + metadata.setProtocol(ProtocolType::Binary); + return true; + })); + EXPECT_CALL(*proto, type()).WillRepeatedly(Return(ProtocolType::Compact)); + + bool underflow = false; + EXPECT_THROW_WITH_MESSAGE(decoder.onData(buffer, underflow), EnvoyException, + "transport reports protocol binary, but configured for compact"); +} + +TEST(DecoderTest, OnDataThrowsTransportAppException) { + NiceMock* transport = new NiceMock(); + NiceMock* proto = new NiceMock(); + NiceMock callbacks; + StrictMock filter; + ON_CALL(callbacks, newDecoderFilter()).WillByDefault(ReturnRef(filter)); + + InSequence dummy; + Decoder decoder(TransportPtr{transport}, ProtocolPtr{proto}, callbacks); + Buffer::OwnedImpl buffer; + + EXPECT_CALL(*transport, decodeFrameStart(Ref(buffer), _)) + .WillOnce(Invoke([&](Buffer::Instance&, MessageMetadata& metadata) -> bool { + metadata.setAppException(AppExceptionType::InvalidTransform, "unknown xform"); + return true; + })); + + bool underflow = false; + EXPECT_THROW_WITH_MESSAGE(decoder.onData(buffer, underflow), AppException, "unknown xform"); +} + TEST(DecoderTest, OnDataResumes) { NiceMock* transport = new NiceMock(); NiceMock* proto = new NiceMock(); diff --git a/test/extensions/filters/network/thrift_proxy/driver/client.py b/test/extensions/filters/network/thrift_proxy/driver/client.py index bbc1293cee55b..c7134ce20ba84 100755 --- a/test/extensions/filters/network/thrift_proxy/driver/client.py +++ b/test/extensions/filters/network/thrift_proxy/driver/client.py @@ -79,6 +79,12 @@ def main(cfg, reqhandle, resphandle): transport, client_type=THeaderTransport.CLIENT_TYPE.HEADER, ) + if cfg.protocol == "binary": + transport.set_protocol_id(THeaderTransport.T_BINARY_PROTOCOL) + elif cfg.protocol == "compact": + transport.set_protocol_id(THeaderTransport.T_COMPACT_PROTOCOL) + else: + sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol)) else: sys.exit("unknown transport {0}".format(cfg.transport)) diff --git a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py index cba5ec0651d9b..76197ccf6e6e7 100644 --- a/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py +++ b/test/extensions/filters/network/thrift_proxy/driver/fbthrift/THeaderTransport.py @@ -658,5 +658,10 @@ def do_POST(self): # INFO:(zuercher): Added to simplify usage class THeaderTransportFactory: + def __init__(self, proto_id): + self.__proto_id = proto_id + def getTransport(self, trans): - return THeaderTransport(trans, client_type=CLIENT_TYPE.HEADER) + header_trans = THeaderTransport(trans, client_type=CLIENT_TYPE.HEADER) + header_trans.set_protocol_id(self.__proto_id) + return header_trans diff --git a/test/extensions/filters/network/thrift_proxy/driver/server.py b/test/extensions/filters/network/thrift_proxy/driver/server.py index 7ab9e862b4e97..e5e9cc3c52bc1 100755 --- a/test/extensions/filters/network/thrift_proxy/driver/server.py +++ b/test/extensions/filters/network/thrift_proxy/driver/server.py @@ -133,7 +133,14 @@ def main(cfg): elif cfg.transport == "unframed": transport_factory = TTransport.TBufferedTransportFactory() elif cfg.transport == "header": - transport_factory = THeaderTransport.THeaderTransportFactory() + if cfg.protocol == "binary": + transport_factory = THeaderTransport.THeaderTransportFactory( + THeaderTransport.T_BINARY_PROTOCOL) + elif cfg.protocol == "compact": + transport_factory = THeaderTransport.THeaderTransportFactory( + THeaderTransport.T_COMPACT_PROTOCOL) + else: + sys.exit("header transport cannot be used with protocol {0}".format(cfg.protocol)) else: sys.exit("unknown transport {0}".format(cfg.transport)) diff --git a/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc new file mode 100644 index 0000000000000..1bcc81120ce58 --- /dev/null +++ b/test/extensions/filters/network/thrift_proxy/header_transport_impl_test.cc @@ -0,0 +1,649 @@ +#include "envoy/common/exception.h" + +#include "common/buffer/buffer_impl.h" + +#include "extensions/filters/network/thrift_proxy/header_transport_impl.h" + +#include "test/extensions/filters/network/thrift_proxy/mocks.h" +#include "test/extensions/filters/network/thrift_proxy/utility.h" +#include "test/mocks/buffer/mocks.h" +#include "test/test_common/printers.h" +#include "test/test_common/utility.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::NiceMock; +using testing::Return; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace ThriftProxy { + +namespace { + +class MockBuffer : public Envoy::MockBuffer { +public: + MockBuffer() {} + ~MockBuffer() {} + + MOCK_CONST_METHOD0(length, uint64_t()); +}; + +MessageMetadata mkMessageMetadata(uint32_t num_headers) { + MessageMetadata metadata; + while (num_headers-- > 0) { + metadata.addHeader(Header("x", "y")); + } + return metadata; +} + +} // namespace + +TEST(HeaderTransportTest, Name) { + HeaderTransportImpl transport; + EXPECT_EQ(transport.name(), "header"); +} + +TEST(HeaderTransportTest, NotEnoughData) { + HeaderTransportImpl transport; + MessageMetadata metadata; + + // Empty buffer + { + Buffer::OwnedImpl buffer; + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } + + // Too short for minimum header + { + Buffer::OwnedImpl buffer; + addRepeated(buffer, 13, 0); + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } + + // Missing header data + { + Buffer::OwnedImpl buffer; + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // header size / 4 + addRepeated(buffer, 3, 0); + EXPECT_FALSE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } +} + +TEST(HeaderTransportTest, InvalidFrameSize) { + HeaderTransportImpl transport; + MessageMetadata metadata; + + { + Buffer::OwnedImpl buffer; + addInt32(buffer, -1); + addRepeated(buffer, 10, 0); + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid thrift header transport frame size -1"); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } + + { + Buffer::OwnedImpl buffer; + addInt32(buffer, 0x7fffffff); + addRepeated(buffer, 10, 0); + + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid thrift header transport frame size 2147483647"); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } +} + +TEST(HeaderTransportTest, InvalidMagic) { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 0x100); + addInt16(buffer, 0x0123); + addRepeated(buffer, 8, 0); + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid thrift header transport magic 0123"); + EXPECT_THAT(metadata, IsEmptyMetadata()); +} + +TEST(HeaderTransportTest, InvalidHeaderSize) { + HeaderTransportImpl transport; + MessageMetadata metadata; + + // Minimum header size is 1 = 4 bytes + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 0x100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 0); + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "no header data"); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } + + // Minimum header size is 1 = 4 bytes + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 0x100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, -1); + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid thrift header transport header size -4 (ffff)"); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } + + // Max header size is 16384 = 65536 bytes + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 0x100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 0x4001); + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid thrift header transport header size 65540 (4001)"); + EXPECT_THAT(metadata, IsEmptyMetadata()); + } + + // Header data extends past stated header size. + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 0x100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // 4 bytes + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // var int -1, exceeds header size + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "unable to read header transport protocol id: header too small"); + } + + // Partial var-int at end of header + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 0x100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // 4 bytes + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF}); // partial var int + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "unable to read header transport protocol id: header too small"); + } +} + +TEST(HeaderTransportTest, InvalidProto) { + HeaderTransportImpl transport; + MessageMetadata metadata; + + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {1, 0, 0, 0}); // 1 = json, 0 = num transforms, pad, pad + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "Unknown protocol 1"); + } + + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {3, 0, 0, 0}); // 3 = invalid proto, 0 = num transforms, pad, pad + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "Unknown protocol 3"); + } + + { + Buffer::OwnedImpl buffer; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 2); // size 8 + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // -1 = invalid proto + addSeq(buffer, {0, 0, 0}); // 0 transforms and padding + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "Unknown protocol -1"); + } +} + +TEST(HeaderTransportTest, NoTransformsOrInfo) { + HeaderTransportImpl transport; + + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {0, 0, 0, 0}); // 0 = binary proto, 0 = num transforms, pad, pad + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(86U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Binary)); + EXPECT_THAT(metadata, HasSequenceId(1)); + EXPECT_THAT(metadata, HasNoHeaders()); + EXPECT_EQ(buffer.length(), 0); + } + + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 101); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 2); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {2, 0, 0, 0}); // 2 = compact proto, 0 = num transforms, pad, pad + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(87U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Compact)); + EXPECT_THAT(metadata, HasSequenceId(2)); + EXPECT_THAT(metadata, HasNoHeaders()); + } +} + +TEST(HeaderTransportTest, TransformErrors) { + MessageMetadata metadata; + + // Invalid number of transforms + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 2); // size 8 + addInt8(buffer, 0); // binary proto + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // -1 = invalid num transforms + addSeq(buffer, {0, 0}); // padding + + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid header transport transform count -1"); + } + + // Unknown transform ids + for (uint8_t xform_id = 1; xform_id < 5; xform_id++) { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {0, 1, xform_id, 0}); // 0 = binary proto, 1 = num transforms, xform id, pad + + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(86U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Binary)); + EXPECT_THAT(metadata, HasAppException(AppExceptionType::MissingResult, + fmt::format("Unknown transform {}", xform_id))); + } + + // Only the first of multiple errors is reported + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {0, 2, 1, 2}); // 0 = binary proto, 2 = num transforms, xform id 1, xform id 2 + + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(86U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Binary)); + EXPECT_THAT(metadata, HasAppException(AppExceptionType::MissingResult, "Unknown transform 1")); + } +} + +TEST(HeaderTransportTest, InvalidInfoBlock) { + // Unknown info block id + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 1); // size 4 + addSeq(buffer, {0, 0, 2, 0}); // 0 = binary proto, 0 = num transforms, 2 = unknown info id, pad + + // Unknown info id is ignored. + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(86U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Binary)); + EXPECT_THAT(metadata, HasSequenceId(1)); + EXPECT_THAT(metadata, HasNoHeaders()); + EXPECT_EQ(buffer.length(), 0); + } + + // Num headers info info block id 1 must be >= 0 + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 3); // size 12 + addSeq(buffer, {0, 0, 1}); // 0 = binary proto, 0 = num transforms, 1 key-value + addSeq(buffer, {0xFF, 0xFF, 0xFF, 0xFF, 0x1F}); // -1 headers + addSeq(buffer, {0, 0, 0, 0}); + + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "invalid header transport header count -1"); + } + + // Header key length exceeds max allowed size + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 2); // size 8 + addSeq(buffer, {0, 0, 1, 1}); // 0 = binary proto, 0 = num transforms, 1 key-value, 1 = num kvs + addSeq(buffer, {0x80, 0x80, 0x40}); // var int 0x100000 + addInt8(buffer, 0); + + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "header transport header key: value 1048576 exceeds max i16 (32767)"); + } + + // Header key extends past stated header size + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 2); // size 8 + addSeq(buffer, {0, 0, 1, 1}); // 0 = binary proto, 0 = num transforms, 1 key-value, 1 = num kvs + addInt8(buffer, 4); // exceeds specified header size + addString(buffer, "key_"); + + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "unable to read header transport header key: header too small"); + } + + // Header key ends at stated header size (no value) + { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + + addInt32(buffer, 100); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 2); // size 8 + addSeq(buffer, {0, 0, 1, 1}); // 0 = binary proto, 0 = num transforms, 1 key-value, 1 = num kvs + addInt8(buffer, 3); // head ends with key, no room for value + addString(buffer, "abc"); + addInt8(buffer, 0); + + EXPECT_THROW_WITH_MESSAGE(transport.decodeFrameStart(buffer, metadata), EnvoyException, + "unable to read header transport header value: header too small"); + } +} + +TEST(HeaderTransportTest, InfoBlock) { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.addHeader(Header("not", "empty")); + + addInt32(buffer, 200); + addInt16(buffer, 0x0FFF); + addInt16(buffer, 0); + addInt32(buffer, 1); // sequence number + addInt16(buffer, 38); // size 152 + addSeq(buffer, {0, 0, 1, 3}); // 0 = binary proto, 0 = num transforms, 1 = key value, 3 = num kvs + addInt8(buffer, 3); + addString(buffer, "key"); + addInt8(buffer, 5); + addString(buffer, "value"); + addInt8(buffer, 4); + addString(buffer, "key2"); + addSeq(buffer, {0x80, 0x01}); // var int 128 + addString(buffer, std::string(128, 'x')); + addInt8(buffer, 0); // empty key + addInt8(buffer, 0); // empty value + addInt8(buffer, 0); // padding + + HeaderMap expected_headers{ + {"not", "empty"}, + {"key", "value"}, + {"key2", std::string(128, 'x')}, + {"", ""}, + }; + + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(38U)); + EXPECT_EQ(expected_headers, metadata.headers()); + EXPECT_EQ(buffer.length(), 0); +} + +TEST(HeaderTransportTest, DecodeFrameEnd) { + HeaderTransportImpl transport; + Buffer::OwnedImpl buffer; + EXPECT_TRUE(transport.decodeFrameEnd(buffer)); +} + +TEST(HeaderTransportImpl, TestEncodeFrame) { + HeaderTransportImpl transport; + + // No message + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + Buffer::OwnedImpl msg; + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "invalid thrift header transport message size 0"); + } + + // No protocol + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + Buffer::OwnedImpl msg; + msg.add("fake message"); + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "missing header transport protocol"); + } + + // Illegal protocol + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Auto); + Buffer::OwnedImpl msg; + msg.add("fake message"); + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "invalid header transport protocol auto"); + } + + // Message too large + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Binary); + + MockBuffer msg; + EXPECT_CALL(msg, length()).WillOnce(Return(0x40000000)); + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "invalid thrift header transport frame size 1073741838"); + } + + // Too many headers + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata = mkMessageMetadata(32769); + metadata.setProtocol(ProtocolType::Binary); + + Buffer::OwnedImpl msg; + msg.add("fake message"); + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "invalid thrift header transport too many headers 32769"); + } + + // Header string too large + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Binary); + metadata.addHeader(Header("key", std::string(32768, 'x'))); + + Buffer::OwnedImpl msg; + msg.add("fake message"); + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "header string too long: 32768"); + } + + // Header info block too large + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Binary); + metadata.addHeader(Header("k1", std::string(16384, 'x'))); + metadata.addHeader(Header("k2", std::string(16384, 'x'))); + metadata.addHeader(Header("k3", std::string(16384, 'x'))); + metadata.addHeader(Header("k4", std::string(16384, 'x'))); + + Buffer::OwnedImpl msg; + msg.add("fake message"); + + EXPECT_THROW_WITH_MESSAGE(transport.encodeFrame(buffer, metadata, msg), EnvoyException, + "invalid thrift header transport header size 65568"); + } + + // Trivial frame with binary protocol + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Binary); + Buffer::OwnedImpl msg; + msg.add("fake message"); + + transport.encodeFrame(buffer, metadata, msg); + + EXPECT_EQ(0, msg.length()); + EXPECT_EQ(std::string("\0\0\0\x1a" + "\xf\xff\0\0" + "\0\0\0\0" + "\0\x1" + "\0\0\0\0" + "fake message", + 30), + buffer.toString()); + } + + // Trivial frame with compact protocol + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Compact); + metadata.setSequenceId(10); + Buffer::OwnedImpl msg; + msg.add("fake message"); + + transport.encodeFrame(buffer, metadata, msg); + + EXPECT_EQ(0, msg.length()); + EXPECT_EQ(std::string("\0\0\0\x1a" + "\xf\xff\0\0" + "\0\0\0\x0a" + "\0\x1" // header size = 4 + "\x2\0\0\0" // compact, no transforms, padding + "fake message", + 30), + buffer.toString()); + } + + // Frame with headers + { + Buffer::OwnedImpl buffer; + MessageMetadata metadata; + metadata.setProtocol(ProtocolType::Compact); + metadata.setSequenceId(10); + metadata.addHeader(Header("key", "value")); + metadata.addHeader(Header("", "")); + Buffer::OwnedImpl msg; + msg.add("fake message"); + + transport.encodeFrame(buffer, metadata, msg); + + EXPECT_EQ(0, msg.length()); + EXPECT_EQ(std::string("\0\0\0\x2a" + "\xf\xff\0\0" + "\0\0\0\x0a" + "\0\x5" // header size = 20 + "\x2\0" // compact, no transforms + "\x1\x2" // header info block, 2 headers + "\x3key\x5value" // first header + "\0\0" // second header + "\0\0\0\0" // padding + "fake message", + 46), + buffer.toString()); + } +} + +} // namespace ThriftProxy +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/thrift_proxy/integration_test.cc b/test/extensions/filters/network/thrift_proxy/integration_test.cc index b50aa45e0afcf..59ba538154077 100644 --- a/test/extensions/filters/network/thrift_proxy/integration_test.cc +++ b/test/extensions/filters/network/thrift_proxy/integration_test.cc @@ -163,11 +163,12 @@ paramToString(const TestParamInfo>& p return fmt::format("{}{}", transport, protocol); } -INSTANTIATE_TEST_CASE_P( - TransportAndProtocol, ThriftConnManagerIntegrationTest, - Combine(Values(TransportNames::get().FRAMED, TransportNames::get().UNFRAMED), - Values(ProtocolNames::get().BINARY, ProtocolNames::get().COMPACT), Values(false, true)), - paramToString); +INSTANTIATE_TEST_CASE_P(TransportAndProtocol, ThriftConnManagerIntegrationTest, + Combine(Values(TransportNames::get().FRAMED, TransportNames::get().UNFRAMED, + TransportNames::get().HEADER), + Values(ProtocolNames::get().BINARY, ProtocolNames::get().COMPACT), + Values(false, true)), + paramToString); TEST_P(ThriftConnManagerIntegrationTest, Success) { initializeCall(CallResult::Success); diff --git a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc index da9f0f90ed93f..621f826c46873 100644 --- a/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/protocol_impl_test.cc @@ -342,6 +342,18 @@ TEST_F(AutoProtocolTest, Type) { EXPECT_EQ(proto.type(), ProtocolType::Auto); } +TEST_F(AutoProtocolTest, SetUnexpectedType) { + Buffer::OwnedImpl buffer; + AutoProtocolImpl proto; + resetMetadata(); + + addInt16(buffer, 0x0102); + + proto.setType(ProtocolType::Auto); + EXPECT_THROW_WITH_MESSAGE(proto.readMessageBegin(buffer, metadata_), EnvoyException, + "unknown thrift auto protocol message start 0102"); +} + } // namespace ThriftProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc index 93ddab8928a63..adf502ea96f74 100644 --- a/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc +++ b/test/extensions/filters/network/thrift_proxy/transport_impl_test.cc @@ -131,6 +131,50 @@ TEST(AutoTransportTest, DecodeFrameStart) { EXPECT_EQ(transport.type(), TransportType::Unframed); EXPECT_EQ(buffer.length(), 8); } + + // Header transport + binary protocol + { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xFF); + addInt16(buffer, 0x0FFF); // header magic + addInt16(buffer, 0x0000); + addInt32(buffer, 0xEE); // sequence id + addInt16(buffer, 1); + addInt32(buffer, 0); // protocol (binary), 0 transforms + padding + addInt16(buffer, 0x8001); + + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(241U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Binary)); + EXPECT_THAT(metadata, HasSequenceId(0xEE)); + EXPECT_EQ(transport.name(), "header(auto)"); + EXPECT_EQ(transport.type(), TransportType::Header); + EXPECT_EQ(buffer.length(), 2); + } + + // Header transport + compact protocol + { + AutoTransportImpl transport; + Buffer::OwnedImpl buffer; + addInt32(buffer, 0xFF); + addInt16(buffer, 0x0FFF); // header magic + addInt16(buffer, 0x0000); + addInt32(buffer, 0xEE); // sequence id + addInt16(buffer, 1); + addInt32(buffer, 0x02000000); // protocol (binary), 0 transforms + padding + addInt16(buffer, 0x8201); + + MessageMetadata metadata; + EXPECT_TRUE(transport.decodeFrameStart(buffer, metadata)); + EXPECT_THAT(metadata, HasFrameSize(241U)); + EXPECT_THAT(metadata, HasProtocol(ProtocolType::Compact)); + EXPECT_THAT(metadata, HasSequenceId(0xEE)); + EXPECT_EQ(transport.name(), "header(auto)"); + EXPECT_EQ(transport.type(), TransportType::Header); + EXPECT_EQ(buffer.length(), 2); + } } TEST(AutoTransportTest, DecodeFrameEnd) {