diff --git a/api/bazel/repositories.bzl b/api/bazel/repositories.bzl index f7ae937e642fa..62cd26e4f445d 100644 --- a/api/bazel/repositories.bzl +++ b/api/bazel/repositories.bzl @@ -31,6 +31,11 @@ def api_dependencies(): locations = REPOSITORY_LOCATIONS, build_file_content = OPENCENSUSTRACE_BUILD_CONTENT, ) + envoy_http_archive( + name = "kafka_source", + locations = REPOSITORY_LOCATIONS, + build_file_content = KAFKASOURCE_BUILD_CONTENT, + ) GOOGLEAPIS_BUILD_CONTENT = """ load("@com_google_protobuf//:protobuf.bzl", "cc_proto_library", "py_proto_library") @@ -285,3 +290,15 @@ go_proto_library( visibility = ["//visibility:public"], ) """ + +KAFKASOURCE_BUILD_CONTENT = """ + +filegroup( + name = "request_protocol_files", + srcs = glob([ + "*Request.json", + ]), + visibility = ["//visibility:public"], +) + +""" diff --git a/api/bazel/repository_locations.bzl b/api/bazel/repository_locations.bzl index e4489eb3b17bd..fd025483dea1e 100644 --- a/api/bazel/repository_locations.bzl +++ b/api/bazel/repository_locations.bzl @@ -16,6 +16,8 @@ GOOGLEAPIS_SHA = "16f5b2e8bf1e747a32f9a62e211f8f33c94645492e9bbd72458061d9a9de1f PROMETHEUS_GIT_SHA = "99fa1f4be8e564e8a6b613da7fa6f46c9edafc6c" # Nov 17, 2017 PROMETHEUS_SHA = "783bdaf8ee0464b35ec0c8704871e1e72afa0005c3f3587f65d9d6694bf3911b" +KAFKA_SOURCE_SHA = "ae7a1696c0a0302b43c5b21e515c37e6ecd365941f68a510a7e442eebddf39a1" # 2.2.0-rc2 + REPOSITORY_LOCATIONS = dict( bazel_skylib = dict( sha256 = BAZEL_SKYLIB_SHA256, @@ -48,4 +50,9 @@ REPOSITORY_LOCATIONS = dict( strip_prefix = "opencensus-proto-" + OPENCENSUS_RELEASE + "/src/opencensus/proto/trace/v1", urls = ["https://github.com/census-instrumentation/opencensus-proto/archive/v" + OPENCENSUS_RELEASE + ".tar.gz"], ), + kafka_source = dict( + sha256 = KAFKA_SOURCE_SHA, + strip_prefix = "kafka-2.2.0-rc2/clients/src/main/resources/common/message", + urls = ["https://github.com/apache/kafka/archive/2.2.0-rc2.zip"], + ), ) diff --git a/source/common/common/logger.h b/source/common/common/logger.h index 3a804c8f4f5b5..1a5f74048fd2f 100644 --- a/source/common/common/logger.h +++ b/source/common/common/logger.h @@ -37,6 +37,7 @@ namespace Logger { FUNCTION(http2) \ FUNCTION(hystrix) \ FUNCTION(init) \ + FUNCTION(kafka) \ FUNCTION(lua) \ FUNCTION(main) \ FUNCTION(misc) \ diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index f973d63fad0a7..183dc0e05d32c 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -69,6 +69,9 @@ EXTENSIONS = { "envoy.filters.network.echo": "//source/extensions/filters/network/echo:config", "envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config", "envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config", + # NOTE: Kafka filter does not have a proper filter implemented right now. We are referencing to + # codec implementation that is going to be used by the filter. 
+ "envoy.filters.network.kafka": "//source/extensions/filters/network/kafka:kafka_request_codec_lib", "envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config", "envoy.filters.network.mysql_proxy": "//source/extensions/filters/network/mysql_proxy:config", "envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config", diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD new file mode 100644 index 0000000000000..73bab5124c8e7 --- /dev/null +++ b/source/extensions/filters/network/kafka/BUILD @@ -0,0 +1,137 @@ +licenses(["notice"]) # Apache 2 + +# Kafka network filter. +# Public docs: docs/root/configuration/network_filters/kafka_filter.rst + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_library", + "envoy_package", +) + +envoy_package() + +envoy_cc_library( + name = "kafka_request_codec_lib", + srcs = ["request_codec.cc"], + hdrs = [ + "codec.h", + "request_codec.h", + ], + deps = [ + ":kafka_request_parser_lib", + "//source/common/buffer:buffer_lib", + ], +) + +envoy_cc_library( + name = "kafka_request_parser_lib", + srcs = [ + "external/kafka_request_resolver.cc", + "kafka_request_parser.cc", + ], + hdrs = [ + "external/requests.h", + "kafka_request_parser.h", + ], + deps = [ + ":kafka_request_lib", + ":parser_lib", + "//source/common/common:assert_lib", + "//source/common/common:minimal_logger_lib", + ], +) + +envoy_cc_library( + name = "kafka_request_lib", + srcs = [ + ], + hdrs = [ + "kafka_request.h", + ], + deps = [ + ":serialization_lib", + ], +) + +genrule( + name = "kafka_generated_source", + srcs = [ + "@kafka_source//:request_protocol_files", + ], + outs = [ + "external/requests.h", + "external/kafka_request_resolver.cc", + ], + cmd = """ + ./$(location :kafka_code_generator) generate-source \ + $(location external/requests.h) $(location external/kafka_request_resolver.cc) \ + $(SRCS) + """, + tools = [ + ":kafka_code_generator", + ], +) + +py_binary( + name = "kafka_code_generator", + srcs = ["protocol_code_generator/kafka_generator.py"], + data = glob(["protocol_code_generator/*.j2"]), + main = "protocol_code_generator/kafka_generator.py", + deps = ["@com_github_pallets_jinja//:jinja2"], +) + +envoy_cc_library( + name = "parser_lib", + hdrs = ["parser.h"], + deps = [ + "//source/common/common:minimal_logger_lib", + ], +) + +envoy_cc_library( + name = "serialization_lib", + hdrs = [ + "external/serialization_composite.h", + "serialization.h", + ], + deps = [ + ":kafka_types_lib", + "//include/envoy/buffer:buffer_interface", + "//source/common/common:byte_order_lib", + ], +) + +genrule( + name = "serialization_composite_generated_source", + srcs = [], + outs = [ + "external/serialization_composite.h", + ], + cmd = """ + ./$(location :serialization_composite_generator) generate-source \ + $(location external/serialization_composite.h) + """, + tools = [ + ":serialization_composite_generator", + ], +) + +py_binary( + name = "serialization_composite_generator", + srcs = ["serialization_code_generator/serialization_composite_generator.py"], + data = glob(["serialization_code_generator/*.j2"]), + main = "serialization_code_generator/serialization_composite_generator.py", + deps = ["@com_github_pallets_jinja//:jinja2"], +) + +envoy_cc_library( + name = "kafka_types_lib", + hdrs = [ + "kafka_types.h", + ], + external_deps = ["abseil_optional"], + deps = [ + "//source/common/common:macros", + ], +) diff --git a/source/extensions/filters/network/kafka/codec.h 
b/source/extensions/filters/network/kafka/codec.h new file mode 100644 index 0000000000000..a58c284a052a1 --- /dev/null +++ b/source/extensions/filters/network/kafka/codec.h @@ -0,0 +1,43 @@ +#pragma once + +#include "envoy/buffer/buffer.h" +#include "envoy/common/pure.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Kafka message decoder. + */ +class MessageDecoder { +public: + virtual ~MessageDecoder() = default; + + /** + * Processes given buffer attempting to decode messages contained within. + * @param data buffer instance. + */ + virtual void onData(Buffer::Instance& data) PURE; +}; + +/** + * Kafka message encoder. + * @param MessageType encoded message type (request or response). + */ +template class MessageEncoder { +public: + virtual ~MessageEncoder() = default; + + /** + * Encodes given message. + * @param message message to be encoded. + */ + virtual void encode(const MessageType& message) PURE; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h new file mode 100644 index 0000000000000..e15605515b75a --- /dev/null +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -0,0 +1,112 @@ +#pragma once + +#include "envoy/common/exception.h" + +#include "extensions/filters/network/kafka/external/serialization_composite.h" +#include "extensions/filters/network/kafka/serialization.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Represents fields that are present in every Kafka request message. + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ +struct RequestHeader { + int16_t api_key_; + int16_t api_version_; + int32_t correlation_id_; + NullableString client_id_; + + bool operator==(const RequestHeader& rhs) const { + return api_key_ == rhs.api_key_ && api_version_ == rhs.api_version_ && + correlation_id_ == rhs.correlation_id_ && client_id_ == rhs.client_id_; + }; +}; + +/** + * Carries information that could be extracted during the failed parse. + */ +class RequestParseFailure { +public: + RequestParseFailure(const RequestHeader& request_header) : request_header_{request_header} {}; + + /** + * Request's header. + */ + const RequestHeader request_header_; +}; + +typedef std::shared_ptr RequestParseFailureSharedPtr; + +/** + * Abstract Kafka request. + * Contains data present in every request (the header with request key, version, etc.). + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ +class AbstractRequest { +public: + virtual ~AbstractRequest() = default; + + /** + * Constructs a request with given header data. + * @param request_header request's header. + */ + AbstractRequest(const RequestHeader& request_header) : request_header_{request_header} {}; + + /** + * Encode the contents of this message into a given buffer. + * @param dst buffer instance to keep serialized message + */ + virtual size_t encode(Buffer::Instance& dst) const PURE; + + /** + * Request's header. + */ + const RequestHeader request_header_; +}; + +typedef std::shared_ptr AbstractRequestSharedPtr; + +/** + * Concrete request that carries data particular to given request type. + * @param Data concrete request data type. 
+ */ +template class Request : public AbstractRequest { +public: + /** + * Request header fields need to be initialized by user in case of newly created requests. + */ + Request(const RequestHeader& request_header, const Data& data) + : AbstractRequest{request_header}, data_{data} {}; + + /** + * Encodes given request into a buffer, with any extra configuration carried by the context. + */ + size_t encode(Buffer::Instance& dst) const override { + EncodingContext context{request_header_.api_version_}; + size_t written{0}; + // Encode request header. + written += context.encode(request_header_.api_key_, dst); + written += context.encode(request_header_.api_version_, dst); + written += context.encode(request_header_.correlation_id_, dst); + written += context.encode(request_header_.client_id_, dst); + // Encode request-specific data. + written += context.encode(data_, dst); + return written; + } + + bool operator==(const Request& rhs) const { + return request_header_ == rhs.request_header_ && data_ == rhs.data_; + }; + +private: + const Data data_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc new file mode 100644 index 0000000000000..0245a9bff24b9 --- /dev/null +++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc @@ -0,0 +1,62 @@ +#include "extensions/filters/network/kafka/kafka_request_parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +const RequestParserResolver& RequestParserResolver::getDefaultInstance() { + CONSTRUCT_ON_FIRST_USE(RequestParserResolver); +} + +RequestParseResponse RequestStartParser::parse(absl::string_view& data) { + request_length_.feed(data); + if (request_length_.ready()) { + context_->remaining_request_size_ = request_length_.get(); + return RequestParseResponse::nextParser( + std::make_shared(parser_resolver_, context_)); + } else { + return RequestParseResponse::stillWaiting(); + } +} + +RequestParseResponse RequestHeaderParser::parse(absl::string_view& data) { + const absl::string_view orig_data = data; + try { + context_->remaining_request_size_ -= deserializer_->feed(data); + } catch (const EnvoyException& e) { + // We were unable to compute the request header, but we still need to consume rest of request + // (some of the data might have been consumed during this attempt). 
+ const int32_t consumed = static_cast(orig_data.size() - data.size()); + context_->remaining_request_size_ -= consumed; + context_->request_header_ = {-1, -1, -1, absl::nullopt}; + return RequestParseResponse::nextParser(std::make_shared(context_)); + } + + if (deserializer_->ready()) { + RequestHeader request_header = deserializer_->get(); + context_->request_header_ = request_header; + RequestParserSharedPtr next_parser = parser_resolver_.createParser( + request_header.api_key_, request_header.api_version_, context_); + return RequestParseResponse::nextParser(next_parser); + } else { + return RequestParseResponse::stillWaiting(); + } +} + +RequestParseResponse SentinelParser::parse(absl::string_view& data) { + const size_t min = std::min(context_->remaining_request_size_, data.size()); + data = {data.data() + min, data.size() - min}; + context_->remaining_request_size_ -= min; + if (0 == context_->remaining_request_size_) { + return RequestParseResponse::parseFailure( + std::make_shared(context_->request_header_)); + } else { + return RequestParseResponse::stillWaiting(); + } +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h new file mode 100644 index 0000000000000..861d4dc4a3a9d --- /dev/null +++ b/source/extensions/filters/network/kafka/kafka_request_parser.h @@ -0,0 +1,194 @@ +#pragma once + +#include + +#include "envoy/common/exception.h" + +#include "common/common/assert.h" + +#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +using RequestParseResponse = ParseResponse; +using RequestParser = Parser; +using RequestParserSharedPtr = std::shared_ptr; + +/** + * Context that is shared between parsers that are handling the same single message. + */ +struct RequestContext { + int32_t remaining_request_size_{0}; + RequestHeader request_header_{}; +}; + +typedef std::shared_ptr RequestContextSharedPtr; + +/** + * Request decoder configuration object. + * Resolves the parser that will be responsible for consuming the request-specific data. + * In other words: provides (api_key, api_version) -> Parser function. + */ +class RequestParserResolver { +public: + virtual ~RequestParserResolver() = default; + + /** + * Creates a parser that is going to process data specific for given api_key & api_version. + * @param api_key request type. + * @param api_version request version. + * @param context context to be used by parser. + * @return parser that is capable of processing data for given request type & version. + */ + virtual RequestParserSharedPtr createParser(int16_t api_key, int16_t api_version, + RequestContextSharedPtr context) const; + + /** + * Return default resolver, that uses request's api key and version to provide a matching parser. + */ + static const RequestParserResolver& getDefaultInstance(); +}; + +/** + * Request parser responsible for consuming request length and setting up context with this data. 
+ * @see http://kafka.apache.org/protocol.html#protocol_common + */ +class RequestStartParser : public RequestParser { +public: + RequestStartParser(const RequestParserResolver& parser_resolver) + : parser_resolver_{parser_resolver}, context_{std::make_shared()} {}; + + /** + * Consumes 4 bytes (INT32) as request length and updates the context with that value. + * @return RequestHeaderParser instance to process request header. + */ + RequestParseResponse parse(absl::string_view& data) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestParserResolver& parser_resolver_; + const RequestContextSharedPtr context_; + Int32Deserializer request_length_; +}; + +/** + * Deserializer that extracts request header (4 fields). + * Can throw, as one of the fields (client-id) can throw (nullable string with invalid length). + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ +class RequestHeaderDeserializer + : public CompositeDeserializerWith4Delegates {}; + +typedef std::unique_ptr RequestHeaderDeserializerPtr; + +/** + * Parser responsible for extracting the request header and putting it into context. + * On a successful parse the resolved data (api_key & api_version) is used to determine the next + * parser. + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ +class RequestHeaderParser : public RequestParser { +public: + // Default constructor. + RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context) + : RequestHeaderParser{parser_resolver, context, + std::make_unique()} {}; + + // Constructor visible for testing (allows for initial parser injection). + RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context, + RequestHeaderDeserializerPtr deserializer) + : parser_resolver_{parser_resolver}, context_{context}, deserializer_{ + std::move(deserializer)} {}; + + /** + * Uses data provided to compute request header. + * @return Parser instance responsible for processing rest of the message + */ + RequestParseResponse parse(absl::string_view& data) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestParserResolver& parser_resolver_; + const RequestContextSharedPtr context_; + RequestHeaderDeserializerPtr deserializer_; +}; + +/** + * Sentinel parser that is responsible for consuming message bytes for messages that had unsupported + * api_key & api_version. It does not attempt to capture any data, just throws it away until end of + * message. + */ +class SentinelParser : public RequestParser { +public: + SentinelParser(RequestContextSharedPtr context) : context_{context} {}; + + /** + * Returns failed parse data. Ignores (jumps over) the data provided. + */ + RequestParseResponse parse(absl::string_view& data) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestContextSharedPtr context_; +}; + +/** + * Request parser uses a single deserializer to construct a request object. + * This parser is responsible for consuming request-specific data (e.g. topic names) and always + * returns a parsed message. + * @param RequestType request class. + * @param DeserializerType deserializer type corresponding to request class (should be subclass of + * Deserializer). + */ +template +class RequestDataParser : public RequestParser { +public: + /** + * Create a parser with given context. 
+ * @param context parse context containing request header. + */ + RequestDataParser(RequestContextSharedPtr context) : context_{context} {}; + + /** + * Consume enough data to fill in deserializer and receive the parsed request. + * Fill in request's header with data stored in context. + */ + RequestParseResponse parse(absl::string_view& data) override { + context_->remaining_request_size_ -= deserializer.feed(data); + + if (deserializer.ready()) { + if (0 == context_->remaining_request_size_) { + // After a successful parse, there should be nothing left - we have consumed all the bytes. + AbstractRequestSharedPtr msg = + std::make_shared>(context_->request_header_, deserializer.get()); + return RequestParseResponse::parsedMessage(msg); + } else { + // The message makes no sense, the deserializer that matches the schema consumed all + // necessary data, but there are still bytes in this message. + return RequestParseResponse::nextParser(std::make_shared(context_)); + } + } else { + return RequestParseResponse::stillWaiting(); + } + } + + const RequestContextSharedPtr contextForTest() const { return context_; } + +protected: + RequestContextSharedPtr context_; + DeserializerType deserializer; // underlying request-specific deserializer +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_types.h b/source/extensions/filters/network/kafka/kafka_types.h new file mode 100644 index 0000000000000..71d1ce920a82d --- /dev/null +++ b/source/extensions/filters/network/kafka/kafka_types.h @@ -0,0 +1,37 @@ +#pragma once + +#include +#include +#include + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Nullable string used by Kafka. + */ +typedef absl::optional NullableString; + +/** + * Bytes array used by Kafka. + */ +typedef std::vector Bytes; + +/** + * Nullable bytes array used by Kafka. + */ +typedef absl::optional NullableBytes; + +/** + * Kafka array of elements of type T. + */ +template using NullableArray = absl::optional>; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/parser.h b/source/extensions/filters/network/kafka/parser.h new file mode 100644 index 0000000000000..031aaaef1dbd3 --- /dev/null +++ b/source/extensions/filters/network/kafka/parser.h @@ -0,0 +1,94 @@ +#pragma once + +#include + +#include "common/common/logger.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +template class ParseResponse; + +/** + * Parser is responsible for consuming data relevant to some part of a message, and then returning + * the decision how the parsing should continue. + */ +template +class Parser : public Logger::Loggable { +public: + virtual ~Parser() = default; + + /** + * Submit data to be processed by parser, will consume as much data as it is necessary to reach + * the conclusion what should be the next parse step. + * @param data bytes to be processed, will be updated by parser if any have been consumed. + * @return parse status - decision what should be done with current parser (keep/replace). 
+ */ + virtual ParseResponse parse(absl::string_view& data) PURE; +}; + +template +using ParserSharedPtr = std::shared_ptr>; + +/** + * Three-state holder representing one of: + * - parser still needs data (`stillWaiting`), + * - parser is finished, and following parser should be used to process the rest of data + * (`nextParser`), + * - parser is finished, and parse result is attached (`parsedMessage` or `parseFailure`). + */ +template class ParseResponse { +public: + /** + * Constructs a response that states that parser still needs data and should not be replaced. + */ + static ParseResponse stillWaiting() { return {nullptr, nullptr, nullptr}; } + + /** + * Constructs a response that states that parser is finished and should be replaced by given + * parser. + */ + static ParseResponse nextParser(ParserSharedPtr next_parser) { + return {next_parser, nullptr, nullptr}; + }; + + /** + * Constructs a response that states that parser is finished, the message is ready, and parsing + * can start anew for next message. + */ + static ParseResponse parsedMessage(MessageType message) { return {nullptr, message, nullptr}; }; + + /** + * Constructs a response that states that parser is finished, the message could not be parsed + * properly, and parsing can start anew for next message. + */ + static ParseResponse parseFailure(FailureDataType failure_data) { + return {nullptr, nullptr, failure_data}; + }; + + /** + * If response contains a next parser or a parse result. + */ + bool hasData() const { + return (next_parser_ != nullptr) || (message_ != nullptr) || (failure_data_ != nullptr); + } + +private: + ParseResponse(ParserSharedPtr parser, MessageType message, + FailureDataType failure_data) + : next_parser_{parser}, message_{message}, failure_data_{failure_data} {}; + +public: + ParserSharedPtr next_parser_; + MessageType message_; + FailureDataType failure_data_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 new file mode 100644 index 0000000000000..46251d5fb5af0 --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 @@ -0,0 +1,64 @@ +{# + Template for structure representing a composite entity in Kafka protocol (e.g. FetchRequest). + Rendered templates for each structure in Kafka protocol will be put into 'requests.h' file. + + Each structure is capable of holding all versions of given entity (what means its fields are + actually a superset of union of all versions' fields). Each version has a dedicated deserializer + (named $requestV$versionDeserializer), which calls the matching constructor. + + To serialize, it is necessary to pass the encoding context (that contains the version that's + being serialized). Depending on the version, the fields will be written to the buffer. +#} +struct {{ complex_type.name }} { + + {# + Constructors invoked by deserializers. + Each constructor has a signature that matches the fields in at least one version (as sometimes + there are different Kafka versions that are actually composed of precisely the same fields). 
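  An illustrative sketch of what this template might render for a hypothetical single-field request
  (the name, field and version range below are invented for orientation only; real output is driven
  by the Kafka spec files):

      struct ExampleRequest {

        const int32_t partition_;

        // constructor used in versions: [0, 1]
        ExampleRequest(int32_t partition): partition_{partition} {};

        size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
          const int16_t api_version = encoder.apiVersion();
          size_t written{0};
          if (api_version >= 0 && api_version < 2) {
            written += encoder.encode(partition_, dst);
          }
          return written;
        }

        bool operator==(const ExampleRequest& rhs) const {
          return true
            && partition_ == rhs.partition_;
        };
      };

      class ExampleRequestV0Deserializer:
        public CompositeDeserializerWith1Delegates<ExampleRequest, Int32Deserializer>{};

  (a V1 counterpart of the deserializer would be rendered as well, one per supported version)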
+ #} + {% for field in complex_type.fields %} + const {{ field.field_declaration() }}_;{% endfor %} + {% for constructor in complex_type.compute_constructors() %} + // constructor used in versions: {{ constructor['versions'] }} + {{ constructor['full_declaration'] }}{% endfor %} + + {# For every field that's used in version, just serialize it. #} + {% if complex_type.fields|length > 0 %} + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + const int16_t api_version = encoder.apiVersion(); + size_t written{0};{% for field in complex_type.fields %} + if (api_version >= {{ field.version_usage[0] }} + && api_version < {{ field.version_usage[-1] + 1 }}) { + written += encoder.encode({{ field.name }}_, dst); + }{% endfor %} + return written; + } + {% else %} + size_t encode(Buffer::Instance&, EncodingContext&) const { + return 0; + } + {% endif %} + + {% if complex_type.fields|length > 0 %} + bool operator==(const {{ complex_type.name }}& rhs) const { + {% else %} + bool operator==(const {{ complex_type.name }}&) const { + {% endif %} + return true{% for field in complex_type.fields %} + && {{ field.name }}_ == rhs.{{ field.name }}_{% endfor %}; + }; + +}; + +{# + Each structure version has a deserializer that matches the structure's field list. +#} +{% for field_list in complex_type.compute_field_lists() %} +class {{ complex_type.name }}V{{ field_list.version }}Deserializer: + public CompositeDeserializerWith{{ field_list.field_count() }}Delegates< + {{ complex_type.name }} + {% for field in field_list.used_fields() %}, + {{ field.deserializer_name_in_version(field_list.version) }} + {% endfor %}>{}; +{% endfor %} + diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py new file mode 100755 index 0000000000000..663c8a58fbf84 --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -0,0 +1,532 @@ +#!/usr/bin/python + + +def main(): + """ + Kafka header generator script + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Generates C++ headers from Kafka protocol specification. + Can generate both main source code, as well as test code. + + Usage: + kafka_generator.py COMMAND OUTPUT FILES INPUT_FILES + where: + COMMAND : 'generate-source', to generate source files, + 'generate-test', to generate test files. + OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', + if generate-test: location of 'requests_test.cc', 'request_codec_request_test.cc'. + INPUT_FILES: Kafka protocol json files to be processed. + + Kafka spec files are provided in Kafka clients jar file. + When generating source code, it creates: + - requests.h - definition of all the structures/deserializers/parsers related to Kafka requests, + - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from + requests.h. + When generating test code, it creates: + - requests_test.cc - serialization/deserialization tests for kafka structures, + - request_codec_request_test.cc - test for all request operations using the codec API. + + Templates used are: + - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2, + - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2, + - to create 'requests_test.cc': requests_test_cc.j2, + - to create 'request_codec_request_test.cc' - request_codec_request_test_cc.j2. 
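    A hypothetical invocation for source generation (illustrative only; in this change the script
    is always driven by the Bazel genrule in the kafka BUILD file, and the spec file names below
    are placeholders):

      kafka_generator.py generate-source external/requests.h external/kafka_request_resolver.cc \
          FetchRequest.json ProduceRequest.json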
+ """ + + import sys + import os + + command = sys.argv[1] + if 'generate-source' == command: + requests_h_file = os.path.abspath(sys.argv[2]) + kafka_request_resolver_cc_file = os.path.abspath(sys.argv[3]) + input_files = sys.argv[4:] + elif 'generate-test' == command: + requests_test_cc_file = os.path.abspath(sys.argv[2]) + request_codec_request_test_cc_file = os.path.abspath(sys.argv[3]) + input_files = sys.argv[4:] + else: + raise ValueError('invalid command: ' + command) + + import re + import json + + requests = [] + + # For each request specification file, remove comments, and parse the remains. + for input_file in input_files: + with open(input_file, 'r') as fd: + raw_contents = fd.read() + without_comments = re.sub(r'//.*\n', '', raw_contents) + request_spec = json.loads(without_comments) + request = parse_request(request_spec) + requests.append(request) + + # Sort requests by api_key. + requests.sort(key=lambda x: x.get_extra('api_key')) + + # Generate main source code. + if 'generate-source' == command: + complex_type_template = RenderingHelper.get_template('complex_type_template.j2') + request_parsers_template = RenderingHelper.get_template('request_parser.j2') + + requests_h_contents = '' + + for request in requests: + # For each child structure that is used by request, render its corresponding C++ code. + for dependency in request.declaration_chain: + requests_h_contents += complex_type_template.render(complex_type=dependency) + # Each top-level structure (e.g. FetchRequest) is going to have corresponding parsers. + requests_h_contents += request_parsers_template.render(complex_type=request) + + # Full file with headers, namespace declaration etc. + template = RenderingHelper.get_template('requests_h.j2') + contents = template.render(contents=requests_h_contents) + + with open(requests_h_file, 'w') as fd: + fd.write(contents) + + template = RenderingHelper.get_template('kafka_request_resolver_cc.j2') + contents = template.render(request_types=requests) + + with open(kafka_request_resolver_cc_file, 'w') as fd: + fd.write(contents) + + # Generate test code. + if 'generate-test' == command: + template = RenderingHelper.get_template('requests_test_cc.j2') + contents = template.render(request_types=requests) + + with open(requests_test_cc_file, 'w') as fd: + fd.write(contents) + + template = RenderingHelper.get_template('request_codec_request_test_cc.j2') + contents = template.render(request_types=requests) + + with open(request_codec_request_test_cc_file, 'w') as fd: + fd.write(contents) + + +def parse_request(spec): + """ + Parse a given structure into a request. + Request is just a complex type, that has name & version information kept in differently named + fields, compared to sub-structures in a request. + """ + request_type_name = spec['name'] + request_versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) + return parse_complex_type(request_type_name, spec, request_versions).with_extra( + 'api_key', spec['apiKey']) + + +def parse_complex_type(type_name, field_spec, versions): + """ + Parse given complex type, returning a structure that holds its name, field specification and + allowed versions. 
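    For example (an illustrative fragment, not copied from a real Kafka message file), a spec such
    as

      {"name": "ExampleData", "fields": [{"name": "topicName", "type": "string", "versions": "0+"}]}

    parsed with versions range(0, 2) yields a Complex named 'ExampleData' whose single string field
    is used in versions 0..1 (the camelCase spec name becomes the generated member 'topic_name_').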
+ """ + fields = [] + for child_field in field_spec['fields']: + child = parse_field(child_field, versions[-1]) + fields.append(child) + return Complex(type_name, fields, versions) + + +def parse_field(field_spec, highest_possible_version): + """ + Parse given field, returning a structure holding the name, type, and versions when this field is + actually used (nullable or not). Obviously, field cannot be used in version higher than its + type's usage. + """ + version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) + version_usage_as_nullable = Statics.parse_version_string( + field_spec['nullableVersions'], + highest_possible_version) if 'nullableVersions' in field_spec else range(-1) + parsed_type = parse_type(field_spec['type'], field_spec, highest_possible_version) + return FieldSpec(field_spec['name'], parsed_type, version_usage, version_usage_as_nullable) + + +def parse_type(type_name, field_spec, highest_possible_version): + """ + Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one. + """ + if (type_name.startswith('[]')): + # In spec files, array types are defined as `[]underlying_type` instead of having its own + # element with type inside. + underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) + return Array(underlying_type) + else: + if (type_name in Primitive.PRIMITIVE_TYPE_NAMES): + return Primitive(type_name, field_spec.get('default')) + else: + versions = Statics.parse_version_string(field_spec['versions'], highest_possible_version) + return parse_complex_type(type_name, field_spec, versions) + + +class Statics: + + @staticmethod + def parse_version_string(raw_versions, highest_possible_version): + """ + Return integer range that corresponds to version string in spec file. + """ + if raw_versions.endswith('+'): + return range(int(raw_versions[:-1]), highest_possible_version + 1) + else: + if '-' in raw_versions: + tokens = raw_versions.split('-', 1) + return range(int(tokens[0]), int(tokens[1]) + 1) + else: + single_version = int(raw_versions) + return range(single_version, single_version + 1) + + +class FieldList: + """ + List of fields used by given entity (request or child structure) in given request version + (as fields get added or removed across versions). + """ + + def __init__(self, version, fields): + self.version = version + self.fields = fields + + def used_fields(self): + """ + Return list of fields that are actually used in this version of structure. + """ + return filter(lambda x: x.used_in_version(self.version), self.fields) + + def constructor_signature(self): + """ + Return constructor signature. + Multiple versions of the same structure can have identical signatures (due to version bumps in + Kafka). + """ + parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) + return ', '.join(parameter_spec) + + def constructor_init_list(self): + """ + Renders member initialization list in constructor. + Takes care of potential optional conversions (as field could be T in V1, but optional + in V2). + """ + init_list = [] + for field in self.fields: + if field.used_in_version(self.version): + if field.is_nullable(): + if field.is_nullable_in_version(self.version): + # Field is optional, and the parameter is optional in this version. + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # Field is optional, and the parameter is T in this version. 
+ init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # Field is T, so parameter cannot be optional. + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # Field is not used in this version, so we need to put in default value. + init_list_item = '%s_{%s}' % (field.name, field.default_value()) + init_list.append(init_list_item) + pass + return ', '.join(init_list) + + def field_count(self): + return len(self.used_fields()) + + def example_value(self): + return ', '.join(map(lambda x: x.example_value_for_test(self.version), self.used_fields())) + + +class FieldSpec: + """ + Represents a field present in a structure (request, or child structure thereof). + Contains name, type, and versions when it is used (nullable or not). + """ + + def __init__(self, name, type, version_usage, version_usage_as_nullable): + import re + separated = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + self.name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', separated).lower() + self.type = type + self.version_usage = version_usage + self.version_usage_as_nullable = version_usage_as_nullable + + def is_nullable(self): + return len(self.version_usage_as_nullable) > 0 + + def is_nullable_in_version(self, version): + """ + Whether thie field is nullable in given version. + Fields can be non-nullable in earlier versions. + See https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message#nullable-fields + """ + return version in self.version_usage_as_nullable + + def used_in_version(self, version): + return version in self.version_usage + + def field_declaration(self): + if self.is_nullable(): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def parameter_declaration(self, version): + if self.is_nullable_in_version(version): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def default_value(self): + if self.is_nullable(): + return '{%s}' % self.type.default_value() + else: + return str(self.type.default_value()) + + def example_value_for_test(self, version): + if self.is_nullable(): + return 'absl::make_optional<%s>(%s)' % (self.type.name, + self.type.example_value_for_test(version)) + else: + return str(self.type.example_value_for_test(version)) + + def deserializer_name_in_version(self, version): + if self.is_nullable_in_version(version): + return 'Nullable%s' % self.type.deserializer_name_in_version(version) + else: + return self.type.deserializer_name_in_version(version) + + def is_printable(self): + return self.type.is_printable() + + +class TypeSpecification: + + def deserializer_name_in_version(self, version): + """ + Renders the deserializer name of given type, in request with given version. + """ + raise NotImplementedError() + + def default_value(self): + """ + Returns a default value for given type. + """ + raise NotImplementedError() + + def example_value_for_test(self, version): + raise NotImplementedError() + + def is_printable(self): + raise NotImplementedError() + + +class Array(TypeSpecification): + """ + Represents array complex type. + To use instance of this type, it is necessary to declare structures required by self.underlying + (e.g. to use Array, we need to have `struct Foo {...}`). 
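    For example (illustrative), an Array wrapping the 'int32' primitive maps to the C++ type
    std::vector<int32_t> and deserializes via ArrayDeserializer<int32_t, Int32Deserializer>.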
+ """ + + def __init__(self, underlying): + self.underlying = underlying + self.declaration_chain = self.underlying.declaration_chain + + @property + def name(self): + return 'std::vector<%s>' % self.underlying.name + + def deserializer_name_in_version(self, version): + return 'ArrayDeserializer<%s, %s>' % (self.underlying.name, + self.underlying.deserializer_name_in_version(version)) + + def default_value(self): + return '{}' + + def example_value_for_test(self, version): + return 'std::vector<%s>{ %s }' % (self.underlying.name, + self.underlying.example_value_for_test(version)) + + def is_printable(self): + return self.underlying.is_printable() + + +class Primitive(TypeSpecification): + """ + Represents a Kafka primitive value. + """ + + PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] + + KAFKA_TYPE_TO_ENVOY_TYPE = { + 'string': 'std::string', + 'bool': 'bool', + 'int8': 'int8_t', + 'int16': 'int16_t', + 'int32': 'int32_t', + 'int64': 'int64_t', + 'bytes': 'Bytes', + } + + KAFKA_TYPE_TO_DESERIALIZER = { + 'string': 'StringDeserializer', + 'bool': 'BooleanDeserializer', + 'int8': 'Int8Deserializer', + 'int16': 'Int16Deserializer', + 'int32': 'Int32Deserializer', + 'int64': 'Int64Deserializer', + 'bytes': 'BytesDeserializer', + } + + # See https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages + KAFKA_TYPE_TO_DEFAULT_VALUE = { + 'string': '""', + 'bool': 'false', + 'int8': '0', + 'int16': '0', + 'int32': '0', + 'int64': '0', + 'bytes': '{}', + } + + # Custom values that make test code more readable. + KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { + 'string': '"string"', + 'bool': 'false', + 'int8': 'static_cast(8)', + 'int16': 'static_cast(16)', + 'int32': 'static_cast(32)', + 'int64': 'static_cast(64)', + 'bytes': 'Bytes({0, 1, 2, 3})', + } + + def __init__(self, name, custom_default_value): + self.original_name = name + self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) + self.custom_default_value = custom_default_value + self.declaration_chain = [] + self.deserializer_name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) + + @staticmethod + def compute(name, map): + if name in map: + return map[name] + else: + raise ValueError(name) + + def deserializer_name_in_version(self, version): + return self.deserializer_name + + def default_value(self): + if self.custom_default_value is not None: + return self.custom_default_value + else: + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) + + def example_value_for_test(self, version): + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) + + def is_printable(self): + return self.name not in ['Bytes'] + + +class Complex(TypeSpecification): + """ + Represents a complex type (multiple types aggregated into one). + This type gets mapped to a C++ struct. + """ + + def __init__(self, name, fields, versions): + self.name = name + self.fields = fields + self.versions = versions + self.declaration_chain = self.__compute_declaration_chain() + self.attributes = {} + + def __compute_declaration_chain(self): + """ + Computes all dependendencies, what means all non-primitive types used by this type. + They need to be declared before this struct is declared. 
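    For example (mirroring the FetchRequest family described in requests_h.j2, illustrative only),
    the chain computed for FetchRequest lists FetchRequestPartition and FetchRequestTopic before
    FetchRequest itself, so the child structs are rendered first.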
+ """ + result = [] + for field in self.fields: + result.extend(field.type.declaration_chain) + result.append(self) + return result + + def with_extra(self, key, value): + self.attributes[key] = value + return self + + def get_extra(self, key): + return self.attributes[key] + + def compute_constructors(self): + """ + Field lists for different versions may not differ (as Kafka can bump version without any + changes). But constructors need to be unique, so we need to remove duplicates if the signatures + match. + """ + signature_to_constructor = {} + for field_list in self.compute_field_lists(): + signature = field_list.constructor_signature() + constructor = signature_to_constructor.get(signature) + if constructor is None: + entry = {} + entry['versions'] = [field_list.version] + entry['signature'] = signature + if (len(signature) > 0): + entry['full_declaration'] = '%s(%s): %s {};' % (self.name, signature, + field_list.constructor_init_list()) + else: + entry['full_declaration'] = '%s() {};' % self.name + signature_to_constructor[signature] = entry + else: + constructor['versions'].append(field_list.version) + return sorted(signature_to_constructor.values(), key=lambda x: x['versions'][0]) + + def compute_field_lists(self): + """ + Return field lists representing each of structure versions. + """ + field_lists = [] + for version in self.versions: + field_list = FieldList(version, self.fields) + field_lists.append(field_list) + return field_lists + + def deserializer_name_in_version(self, version): + return '%sV%dDeserializer' % (self.name, version) + + def default_value(self): + raise NotImplementedError('unable to create default value of complex type') + + def example_value_for_test(self, version): + field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) + example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) + return '%s(%s)' % (self.name, ', '.join(example_values)) + + def is_printable(self): + return True + + +class RenderingHelper: + """ + Helper for jinja templates. + """ + + @staticmethod + def get_template(template): + import jinja2 + import os + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(searchpath=os.path.dirname(os.path.abspath(__file__)))) + return env.get_template(template) + + +if __name__ == "__main__": + main() diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 new file mode 100644 index 0000000000000..d73f76955adca --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 @@ -0,0 +1,37 @@ +{# + Template for 'kafka_request_resolver.cc'. + Defines default Kafka request resolver, that uses request parsers in (also generated) + 'requests.h'. +#} +#include "extensions/filters/network/kafka/external/requests.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" +#include "extensions/filters/network/kafka/parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Creates a parser that corresponds to provided key and version. + * If corresponding parser cannot be found (what means a newer version of Kafka protocol), + * a sentinel parser is returned. 
+ * @param api_key Kafka request key + * @param api_version Kafka request's version + * @param context parse context + */ +RequestParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version, + RequestContextSharedPtr context) const { + +{% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %} + if ({{ request_type.get_extra('api_key') }} == api_key + && {{ field_list.version }} == api_version) { + return std::make_shared<{{ request_type.name }}V{{ field_list.version }}Parser>(context); + }{% endfor %}{% endfor %} + return std::make_shared(context); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 new file mode 100644 index 0000000000000..c853563f8f8a9 --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 @@ -0,0 +1,89 @@ +{# + Template for 'request_codec_request_test.cc'. + + Provides integration tests using Kafka codec. + The tests do the following: + - create the message, + - serialize the message into buffer, + - pass the buffer to the codec, + - capture messages received in callback, + - verify that captured messages are identical to the ones sent. +#} +#include "extensions/filters/network/kafka/external/requests.h" +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/extensions/filters/network/kafka/serialization_utilities.h" +#include "test/mocks/server/mocks.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { +namespace RequestCodecRequestTest { + +class RequestCodecRequestTest : public testing::Test { +protected: + template void putInBuffer(T arg); + + Buffer::OwnedImpl buffer_; +}; + +{% for request_type in request_types %} + +// Integration test for {{ request_type.name }} messages. 
+ +TEST_F(RequestCodecRequestTest, shouldHandle{{ request_type.name }}Messages) { + // given + using RequestUnderTest = Request<{{ request_type.name }}>; + + std::vector sent; + int32_t correlation = 0; + + {% for field_list in request_type.compute_field_lists() %} + for (int i = 0; i < 100; ++i ) { + const RequestHeader header = + { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation++, "id" }; + const {{ request_type.name }} data = { {{ field_list.example_value() }} }; + const RequestUnderTest request = {header, data}; + putInBuffer(request); + sent.push_back(request); + } + {% endfor %} + + const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance(); + const RequestParserResolver& request_parser_resolver = + RequestParserResolver::getDefaultInstance(); + const CapturingRequestCallbackSharedPtr request_callback = + std::make_shared(); + + RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}}; + + // when + testee.onData(buffer_); + + // then + const std::vector& received = request_callback->getCaptured(); + ASSERT_EQ(received.size(), sent.size()); + + for (size_t i = 0; i < received.size(); ++i) { + const std::shared_ptr request = + std::dynamic_pointer_cast(received[i]); + ASSERT_NE(request, nullptr); + ASSERT_EQ(*request, sent[i]); + } +} +{% endfor %} + +template +void RequestCodecRequestTest::putInBuffer(const T arg) { + RequestEncoder encoder{buffer_}; + encoder.encode(arg); +} + +} // namespace RequestCodecRequestTest +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 new file mode 100644 index 0000000000000..db14f9e2a55cf --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 @@ -0,0 +1,20 @@ +{# + Template for top-level structure representing a request in Kafka protocol (e.g. ProduceRequest). + Rendered templates for each request in Kafka protocol will be put into 'requests.h' file. + + This template handles binding the top-level structure deserializer + (e.g. ProduceRequestV0Deserializer) with RequestDataParser. These parsers are then used by + RequestParserResolver instance depending on received Kafka api key & api version + (see 'kafka_request_resolver_cc.j2'). +#} + +{% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: + public RequestDataParser< + {{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> +{ +public: + {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : + RequestDataParser{ctx} {}; +}; + +{% endfor %} \ No newline at end of file diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 new file mode 100644 index 0000000000000..ff85d19410d07 --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 @@ -0,0 +1,35 @@ +{# + Main template for 'requests.h' file. + Gets filled in (by 'contents') with Kafka request structures, deserializers, and parsers. + + For each request we have the following: + - 1 top-level structure corresponding to the request (e.g. 
`struct FetchRequest`), + - N deserializers for top-level structure, one for each request version, + - N parsers binding each deserializer with parser, + - 0+ child structures (e.g. `struct FetchRequestTopic`, `FetchRequestPartition`) that are used by + the request's top-level structure, + - deserializers for each child structure. + + So for example, for FetchRequest we have: + - struct FetchRequest, + - FetchRequestV0Deserializer, FetchRequestV1Deserializer, FetchRequestV2Deserializer, etc., + - FetchRequestV0Parser, FetchRequestV1Parser, FetchRequestV2Parser, etc., + - struct FetchRequestTopic, + - FetchRequestTopicV0Deserializer, FetchRequestTopicV1Deserializer, etc. + (because topic data is present in every FetchRequest version), + - struct FetchRequestPartition, + - FetchRequestPartitionV0Deserializer, FetchRequestPartitionV1Deserializer, etc. + (because partition data is present in every FetchRequestTopic version). +#} +#pragma once +#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +{{ contents }} + +}}}} diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 new file mode 100644 index 0000000000000..d7ec7ae98ca4f --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -0,0 +1,79 @@ +{# + Template for request serialization/deserialization tests. + For every request, we want to check if it can be serialized and deserialized properly. +#} + +#include "extensions/filters/network/kafka/external/requests.h" +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { +namespace RequestTest { + +class RequestTest : public testing::Test { +public: + Buffer::OwnedImpl buffer_; + + template std::shared_ptr serializeAndDeserialize(T request); +}; + +class MockMessageListener : public RequestCallback { +public: + MOCK_METHOD1(onMessage, void(AbstractRequestSharedPtr)); + MOCK_METHOD1(onFailedParse, void(RequestParseFailureSharedPtr)); +}; + +/** + * Helper method. + * Takes an instance of a request, serializes it, then deserializes it. + * This method gets executed for every request * version pair. + */ +template std::shared_ptr RequestTest::serializeAndDeserialize(T request) { + RequestEncoder encoder{buffer_}; + encoder.encode(request); + + std::shared_ptr mock_listener = std::make_shared(); + RequestDecoder testee{RequestParserResolver::getDefaultInstance(), {mock_listener}}; + + AbstractRequestSharedPtr receivedMessage; + EXPECT_CALL(*mock_listener, onMessage(testing::_)) + .WillOnce(testing::SaveArg<0>(&receivedMessage)); + + testee.onData(buffer_); + + return std::dynamic_pointer_cast(receivedMessage); +}; + +{# + Concrete tests for each request_type and version (field_list). + Each request is naively constructed using some default values + (put "string" as std::string, 32 as uint32_t, etc.). 
+#} +{% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %} +TEST_F(RequestTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) { + // given + {{ request_type.name }} data = { {{ field_list.example_value() }} }; + Request<{{ request_type.name }}> request = { { + {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, 0, absl::nullopt }, data }; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} +{% endfor %}{% endfor %} + +} // namespace RequestTest +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc new file mode 100644 index 0000000000000..7b46c77a0f76a --- /dev/null +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -0,0 +1,96 @@ +#include "extensions/filters/network/kafka/request_codec.h" + +#include "common/buffer/buffer_impl.h" +#include "common/common/stack_array.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +class RequestStartParserFactory : public InitialParserFactory { + RequestParserSharedPtr create(const RequestParserResolver& parser_resolver) const override { + return std::make_shared(parser_resolver); + } +}; + +const InitialParserFactory& InitialParserFactory::getDefaultInstance() { + CONSTRUCT_ON_FIRST_USE(RequestStartParserFactory); +} + +void RequestDecoder::onData(Buffer::Instance& data) { + // Convert buffer to slices and pass them to `doParse`. + uint64_t num_slices = data.getRawSlices(nullptr, 0); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + data.getRawSlices(slices.begin(), num_slices); + for (const Buffer::RawSlice& slice : slices) { + doParse(slice); + } +} + +/** + * Main parse loop: + * - forward data to current parser, + * - receive parser response: + * -- if still waiting, do nothing (we wait for more data), + * -- if a parser is given, replace current parser with the new one, and it the rest of the data + * -- if a message is given: + * --- notify callbacks, + * --- replace current parser with new start parser, as we are going to start parsing the next + * message. + */ +void RequestDecoder::doParse(const Buffer::RawSlice& slice) { + const char* bytes = reinterpret_cast(slice.mem_); + absl::string_view data = {bytes, slice.len_}; + + while (!data.empty()) { + + // Feed the data to the parser. + RequestParseResponse result = current_parser_->parse(data); + // This loop guarantees that parsers consuming 0 bytes also get processed in this invocation. + while (result.hasData()) { + if (!result.next_parser_) { + + // Next parser is not present, so we have finished parsing a message. + // Depending on whether the parse was successful, invoke the correct callback. + if (result.message_) { + for (auto& callback : callbacks_) { + callback->onMessage(result.message_); + } + } else { + for (auto& callback : callbacks_) { + callback->onFailedParse(result.failure_data_); + } + } + + // As we finished parsing this request, re-initialize the parser. + current_parser_ = factory_.create(parser_resolver_); + } else { + + // The next parser that's supposed to consume the rest of payload was given. + current_parser_ = result.next_parser_; + } + + // Keep parsing the data. 
+ result = current_parser_->parse(data); + } + } +} + +void RequestEncoder::encode(const AbstractRequest& message) { + Buffer::OwnedImpl data_buffer; + // TODO(adamkotwasinski) Precompute the size instead of using temporary buffer. + // When we have the 'computeSize' method, then we can push encoding request's size into + // Request::encode + int32_t data_len = message.encode(data_buffer); // Encode data and compute data length. + EncodingContext encoder{-1}; + encoder.encode(data_len, output_); // Encode data length into result. + output_.add(data_buffer); // Copy encoded data into result. +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h new file mode 100644 index 0000000000000..c8a6b69f87973 --- /dev/null +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -0,0 +1,127 @@ +#pragma once + +#include "envoy/buffer/buffer.h" +#include "envoy/common/pure.h" + +#include "extensions/filters/network/kafka/codec.h" +#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" +#include "extensions/filters/network/kafka/parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Callback invoked when request is successfully decoded. + */ +class RequestCallback { +public: + virtual ~RequestCallback() = default; + + /** + * Callback method invoked when request is successfully decoded. + * @param request request that has been decoded. + */ + virtual void onMessage(AbstractRequestSharedPtr request) PURE; + + /** + * Callback method invoked when request could not be decoded. + * Invoked after all request's bytes have been consumed. + */ + virtual void onFailedParse(RequestParseFailureSharedPtr failure_data) PURE; +}; + +typedef std::shared_ptr RequestCallbackSharedPtr; + +/** + * Provides initial parser for messages (class extracted to allow injecting test factories). + */ +class InitialParserFactory { +public: + virtual ~InitialParserFactory() = default; + + /** + * Creates default instance that returns RequestStartParser instances. + */ + static const InitialParserFactory& getDefaultInstance(); + + /** + * Creates parser with given context. + */ + virtual RequestParserSharedPtr create(const RequestParserResolver& parser_resolver) const PURE; +}; + +/** + * Decoder that decodes Kafka requests. + * When a request is decoded, the callbacks are notified, in order. + * + * This decoder uses chain of parsers to parse fragments of a request. + * Each parser along the line returns the fully parsed message or the next parser. + * Stores parse state (as large message's payload can be provided through multiple `onData` calls). + */ +class RequestDecoder : public MessageDecoder { +public: + /** + * Creates a decoder that can decode requests specified by RequestParserResolver, notifying + * callbacks on successful decoding. + * @param parserResolver supported parser resolver. + * @param callbacks callbacks to be invoked (in order). + */ + RequestDecoder(const RequestParserResolver& parserResolver, + const std::vector callbacks) + : RequestDecoder(InitialParserFactory::getDefaultInstance(), parserResolver, callbacks){}; + + /** + * Visible for testing. + * Allows injecting initial parser factory. 
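For illustration (not part of this change), a rough usage sketch of this header: LoggingCallback is a hypothetical RequestCallback implementation; the decoder and resolver calls are the ones declared in this patch.

#include "extensions/filters/network/kafka/request_codec.h"

// Sketch only: wire the decoder to a callback and feed it whatever the network delivers.
class LoggingCallback : public RequestCallback {
public:
  void onMessage(AbstractRequestSharedPtr) override { /* e.g. update stats, inspect the request */ }
  void onFailedParse(RequestParseFailureSharedPtr) override { /* e.g. count unparseable requests */ }
};

void decodeExample(Buffer::Instance& incoming_data) {
  RequestDecoder decoder{RequestParserResolver::getDefaultInstance(),
                         {std::make_shared<LoggingCallback>()}};
  decoder.onData(incoming_data); // May be invoked repeatedly; parse state is kept between calls.
}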
+ */ + RequestDecoder(const InitialParserFactory& factory, const RequestParserResolver& parserResolver, + const std::vector callbacks) + : factory_{factory}, parser_resolver_{parserResolver}, callbacks_{callbacks}, + current_parser_{factory_.create(parser_resolver_)} {}; + + /** + * Consumes all data present in a buffer. + * If a request can be successfully parsed, then callbacks get notified with parsed request. + * Updates decoder state. + * Impl note: similar to redis codec, which also keeps state. + */ + void onData(Buffer::Instance& data) override; + +private: + void doParse(const Buffer::RawSlice& slice); + + const InitialParserFactory& factory_; + + const RequestParserResolver& parser_resolver_; + + const std::vector callbacks_; + + RequestParserSharedPtr current_parser_; +}; + +/** + * Encodes requests into underlying buffer. + */ +class RequestEncoder : public MessageEncoder { +public: + /** + * Wraps buffer with encoder. + */ + RequestEncoder(Buffer::Instance& output) : output_(output) {} + + /** + * Encodes request into wrapped buffer. + */ + void encode(const AbstractRequest& message) override; + +private: + Buffer::Instance& output_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h new file mode 100644 index 0000000000000..38542c44d8761 --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization.h @@ -0,0 +1,767 @@ +#pragma once + +#include +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" +#include "envoy/common/pure.h" + +#include "common/common/byte_order.h" +#include "common/common/fmt.h" + +#include "extensions/filters/network/kafka/kafka_types.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Deserializer is a stateful entity that constructs a result of type T from bytes provided. + * It can be feed()-ed data until it is ready, filling the internal store. + * When ready(), it is safe to call get() to transform the internally stored bytes into result. + * Further feed()-ing should have no effect on a buffer (should return 0 and not move + * provided pointer). + * @param T type of deserialized data. + */ +template class Deserializer { +public: + virtual ~Deserializer() = default; + + /** + * Submit data to be processed, will consume as much data as it is necessary. + * If any bytes are consumed, then the provided string view is updated by stepping over consumed + * bytes. Invoking this method when deserializer is ready has no effect (consumes 0 bytes). + * @param data bytes to be processed, will be updated if any have been consumed. + * @return number of bytes consumed (equal to change in 'data'). + */ + virtual size_t feed(absl::string_view& data) PURE; + + /** + * Whether deserializer has consumed enough data to return result. + */ + virtual bool ready() const PURE; + + /** + * Returns the entity that is represented by bytes stored in this deserializer. + * Should be only called when deserializer is ready. + */ + virtual T get() const PURE; +}; + +/** + * Generic integer deserializer (uses array of sizeof(T) bytes). + * After all bytes are filled in, the value is converted from network byte-order and returned. 
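To make the feed()/ready()/get() contract concrete, a small sketch (assuming the Int32Deserializer declared below); the bytes are the big-endian encoding of 258, split across two chunks.

#include "extensions/filters/network/kafka/serialization.h"

// Sketch only: a deserializer is fed incrementally and stops consuming once it is ready.
void int32FeedExample() {
  Int32Deserializer deserializer;

  const char first_chunk[] = {0x00, 0x00}; // first half of big-endian 0x00000102 (= 258)
  absl::string_view data1{first_chunk, sizeof(first_chunk)};
  deserializer.feed(data1); // consumes 2 bytes; deserializer.ready() is still false

  const char second_chunk[] = {0x01, 0x02, 0x7F}; // rest of the int32, plus one unrelated byte
  absl::string_view data2{second_chunk, sizeof(second_chunk)};
  deserializer.feed(data2); // consumes 2 bytes only; data2 now points at the unrelated byte
  // deserializer.ready() == true, deserializer.get() == 258, data2.size() == 1
}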
+ */ +template class IntDeserializer : public Deserializer { +public: + IntDeserializer() : written_{0}, ready_(false){}; + + size_t feed(absl::string_view& data) override { + const size_t available = std::min(sizeof(buf_) - written_, data.size()); + memcpy(buf_ + written_, data.data(), available); + written_ += available; + + if (written_ == sizeof(buf_)) { + ready_ = true; + } + + data = {data.data() + available, data.size() - available}; + + return available; + } + + bool ready() const override { return ready_; } + +protected: + char buf_[sizeof(T) / sizeof(char)]; + size_t written_; + bool ready_{false}; +}; + +/** + * Integer deserializer for int8_t. + */ +class Int8Deserializer : public IntDeserializer { +public: + int8_t get() const override { + int8_t result; + memcpy(&result, buf_, sizeof(result)); + return result; + } +}; + +/** + * Integer deserializer for int16_t. + */ +class Int16Deserializer : public IntDeserializer { +public: + int16_t get() const override { + int16_t result; + memcpy(&result, buf_, sizeof(result)); + return be16toh(result); + } +}; + +/** + * Integer deserializer for int32_t. + */ +class Int32Deserializer : public IntDeserializer { +public: + int32_t get() const override { + int32_t result; + memcpy(&result, buf_, sizeof(result)); + return be32toh(result); + } +}; + +/** + * Integer deserializer for uint32_t. + */ +class UInt32Deserializer : public IntDeserializer { +public: + uint32_t get() const override { + uint32_t result; + memcpy(&result, buf_, sizeof(result)); + return be32toh(result); + } +}; + +/** + * Integer deserializer for uint64_t. + */ +class Int64Deserializer : public IntDeserializer { +public: + int64_t get() const override { + int64_t result; + memcpy(&result, buf_, sizeof(result)); + return be64toh(result); + } +}; + +/** + * Deserializer for boolean values. + * Uses a single int8 deserializer, and checks whether the results equals 0. + * When reading a boolean value, any non-zero value is considered true. + * Impl note: could have been a subclass of IntDeserializer with a different get function, + * but it makes it harder to understand. + */ +class BooleanDeserializer : public Deserializer { +public: + BooleanDeserializer(){}; + + size_t feed(absl::string_view& data) override { return buffer_.feed(data); } + + bool ready() const override { return buffer_.ready(); } + + bool get() const override { return 0 != buffer_.get(); } + +private: + Int8Deserializer buffer_; +}; + +/** + * Deserializer of string value. + * First reads length (INT16) and then allocates the buffer of given length. + * + * From Kafka documentation: + * First the length N is given as an INT16. + * Then N bytes follow which are the UTF-8 encoding of the character sequence. + * Length must not be negative. + */ +class StringDeserializer : public Deserializer { +public: + /** + * Can throw EnvoyException if given string length is not valid. + */ + size_t feed(absl::string_view& data) override { + const size_t length_consumed = length_buf_.feed(data); + if (!length_buf_.ready()) { + // Break early: we still need to fill in length buffer. 
+ return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } else { + throw EnvoyException(fmt::format("invalid STRING length: {}", required_)); + } + length_consumed_ = true; + } + + const size_t data_consumed = std::min(required_, data.size()); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, data.data(), data_consumed); + required_ -= data_consumed; + + data = {data.data() + data_consumed, data.size() - data_consumed}; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const override { return ready_; } + + std::string get() const override { return std::string(data_buf_.begin(), data_buf_.end()); } + +private: + Int16Deserializer length_buf_; + bool length_consumed_{false}; + + int16_t required_; + std::vector data_buf_; + + bool ready_{false}; +}; + +/** + * Deserializer of nullable string value. + * First reads length (INT16) and then allocates the buffer of given length. + * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning + * null value). + * + * From Kafka documentation: + * For non-null strings, first the length N is given as an INT16. + * Then N bytes follow which are the UTF-8 encoding of the character sequence. + * A null value is encoded with length of -1 and there are no following bytes. + */ +class NullableStringDeserializer : public Deserializer { +public: + /** + * Can throw EnvoyException if given string length is not valid. + */ + size_t feed(absl::string_view& data) override { + const size_t length_consumed = length_buf_.feed(data); + if (!length_buf_.ready()) { + // Break early: we still need to fill in length buffer. + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } + if (required_ == NULL_STRING_LENGTH) { + ready_ = true; + } + if (required_ < NULL_STRING_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_STRING length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + const size_t data_consumed = std::min(required_, data.size()); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, data.data(), data_consumed); + required_ -= data_consumed; + + data = {data.data() + data_consumed, data.size() - data_consumed}; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const override { return ready_; } + + NullableString get() const override { + return required_ >= 0 ? absl::make_optional(std::string(data_buf_.begin(), data_buf_.end())) + : absl::nullopt; + } + +private: + constexpr static int16_t NULL_STRING_LENGTH{-1}; + + Int16Deserializer length_buf_; + bool length_consumed_{false}; + + int16_t required_; + std::vector data_buf_; + + bool ready_{false}; +}; + +/** + * Deserializer of bytes value. + * First reads length (INT32) and then allocates the buffer of given length. + * + * From Kafka documentation: + * First the length N is given as an INT32. Then N bytes follow. + */ +class BytesDeserializer : public Deserializer { +public: + /** + * Can throw EnvoyException if given bytes length is not valid. 
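A worked example of the STRING / NULLABLE_STRING wire format described above (sketch): "abc" is a two-byte length followed by the UTF-8 bytes, while a null nullable string is just the length -1.

#include "extensions/filters/network/kafka/serialization.h"

// Sketch only: byte layout consumed by the string deserializers above.
void stringWireFormatExample() {
  // STRING "abc" = INT16 length 0x0003, then 'a' 'b' 'c'.
  const char encoded_string[] = {0x00, 0x03, 'a', 'b', 'c'};
  absl::string_view string_data{encoded_string, sizeof(encoded_string)};
  StringDeserializer string_deserializer;
  string_deserializer.feed(string_data); // consumes all 5 bytes; get() == "abc"

  // Null NULLABLE_STRING = INT16 length -1 (0xFF 0xFF), no payload bytes follow.
  const char encoded_null[] = {'\xFF', '\xFF'};
  absl::string_view null_data{encoded_null, sizeof(encoded_null)};
  NullableStringDeserializer nullable_deserializer;
  nullable_deserializer.feed(null_data); // consumes 2 bytes; get() == absl::nullopt
}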
+ */ + size_t feed(absl::string_view& data) override { + const size_t length_consumed = length_buf_.feed(data); + if (!length_buf_.ready()) { + // Break early: we still need to fill in length buffer. + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } else { + throw EnvoyException(fmt::format("invalid BYTES length: {}", required_)); + } + length_consumed_ = true; + } + + const size_t data_consumed = std::min(required_, data.size()); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, data.data(), data_consumed); + required_ -= data_consumed; + + data = {data.data() + data_consumed, data.size() - data_consumed}; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const override { return ready_; } + + Bytes get() const override { return data_buf_; } + +private: + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + + std::vector data_buf_; + bool ready_{false}; +}; + +/** + * Deserializer of nullable bytes value. + * First reads length (INT32) and then allocates the buffer of given length. + * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning + * null value). + * + * From Kafka documentation: + * For non-null values, first the length N is given as an INT32. Then N bytes follow. + * A null value is encoded with length of -1 and there are no following bytes. + */ +class NullableBytesDeserializer : public Deserializer { +public: + /** + * Can throw EnvoyException if given bytes length is not valid. + */ + size_t feed(absl::string_view& data) override { + const size_t length_consumed = length_buf_.feed(data); + if (!length_buf_.ready()) { + // Break early: we still need to fill in length buffer. + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } + if (required_ == NULL_BYTES_LENGTH) { + ready_ = true; + } + if (required_ < NULL_BYTES_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + const size_t data_consumed = std::min(required_, data.size()); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, data.data(), data_consumed); + required_ -= data_consumed; + + data = {data.data() + data_consumed, data.size() - data_consumed}; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const override { return ready_; } + + NullableBytes get() const override { + if (NULL_BYTES_LENGTH == required_) { + return absl::nullopt; + } else { + return {data_buf_}; + } + } + +private: + constexpr static int32_t NULL_BYTES_LENGTH{-1}; + + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + + std::vector data_buf_; + bool ready_{false}; +}; + +/** + * Deserializer for array of objects of the same type. + * + * First reads the length of the array, then initializes N underlying deserializers of type + * DeserializerType. After the last of N deserializers is ready, the results of each of them are + * gathered and put in a vector. + * @param ResponseType result type returned by deserializer of type DeserializerType. + * @param DeserializerType underlying deserializer type. 
+ * + * From Kafka documentation: + * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. + * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T + * follow. A null array is represented with a length of -1. + */ +template +class ArrayDeserializer : public Deserializer> { +public: + /** + * Can throw EnvoyException if array length is invalid or if underlying deserializer can throw. + */ + size_t feed(absl::string_view& data) override { + + const size_t length_consumed = length_buf_.feed(data); + if (!length_buf_.ready()) { + // Break early: we still need to fill in length buffer. + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + if (required_ >= 0) { + children_ = std::vector(required_); + } else { + throw EnvoyException(fmt::format("invalid ARRAY length: {}", required_)); + } + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + size_t child_consumed{0}; + for (DeserializerType& child : children_) { + child_consumed += child.feed(data); + } + + bool children_ready_ = true; + for (DeserializerType& child : children_) { + children_ready_ &= child.ready(); + } + ready_ = children_ready_; + + return length_consumed + child_consumed; + } + + bool ready() const override { return ready_; } + + std::vector get() const override { + std::vector result{}; + result.reserve(children_.size()); + for (const DeserializerType& child : children_) { + const ResponseType child_result = child.get(); + result.push_back(child_result); + } + return result; + } + +private: + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + std::vector children_; + bool children_setup_{false}; + bool ready_{false}; +}; + +/** + * Deserializer for nullable array of objects of the same type. + * + * First reads the length of the array, then initializes N underlying deserializers of type + * DeserializerType. After the last of N deserializers is ready, the results of each of them are + * gathered and put in a vector. + * @param ResponseType result type returned by deserializer of type DeserializerType. + * @param DeserializerType underlying deserializer type. + * + * From Kafka documentation: + * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. + * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T + * follow. A null array is represented with a length of -1. + */ +template +class NullableArrayDeserializer : public Deserializer> { +public: + /** + * Can throw EnvoyException if array length is invalid or if underlying deserializer can throw. + */ + size_t feed(absl::string_view& data) override { + + const size_t length_consumed = length_buf_.feed(data); + if (!length_buf_.ready()) { + // Break early: we still need to fill in length buffer. 
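For example (sketch, assuming the template parameters are <ResponseType, DeserializerType> in the order documented above): an ARRAY of STRING is a four-byte element count followed by each element encoded as STRING.

#include "extensions/filters/network/kafka/serialization.h"

// Sketch only: ["ab", "c"] on the wire = count (INT32) + each element as STRING.
void arrayWireFormatExample() {
  const char encoded[] = {0x00, 0x00, 0x00, 0x02,  // count = 2
                          0x00, 0x02, 'a',  'b',   // element 1: "ab"
                          0x00, 0x01, 'c'};        // element 2: "c"
  absl::string_view data{encoded, sizeof(encoded)};

  ArrayDeserializer<std::string, StringDeserializer> deserializer;
  deserializer.feed(data); // consumes all 11 bytes
  // deserializer.get() == std::vector<std::string>{"ab", "c"}
}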
+ return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + children_ = std::vector(required_); + } + if (required_ == NULL_ARRAY_LENGTH) { + ready_ = true; + } + if (required_ < NULL_ARRAY_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_ARRAY length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + size_t child_consumed{0}; + for (DeserializerType& child : children_) { + child_consumed += child.feed(data); + } + + bool children_ready_ = true; + for (DeserializerType& child : children_) { + children_ready_ &= child.ready(); + } + ready_ = children_ready_; + + return length_consumed + child_consumed; + } + + bool ready() const override { return ready_; } + + NullableArray get() const override { + if (NULL_ARRAY_LENGTH != required_) { + std::vector result{}; + result.reserve(children_.size()); + for (const DeserializerType& child : children_) { + const ResponseType child_result = child.get(); + result.push_back(child_result); + } + return result; + } else { + return absl::nullopt; + } + } + +private: + constexpr static int32_t NULL_ARRAY_LENGTH{-1}; + + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + std::vector children_; + bool children_setup_{false}; + bool ready_{false}; +}; + +/** + * Encodes provided argument in Kafka format. + * In case of primitive types, this is done explicitly as per specification. + * In case of composite types, this is done by calling 'encode' on provided argument. + * + * This object also carries extra information that is used while traversing the request + * structure-tree during encoding (currently api_version, as different request versions serialize + * differently). + */ +// TODO(adamkotwasinski) that class might be split into Request/ResponseEncodingContext in future +class EncodingContext { +public: + EncodingContext(int16_t api_version) : api_version_{api_version} {}; + + /** + * Encode given reference in a buffer. + * @return bytes written + */ + template size_t encode(const T& arg, Buffer::Instance& dst); + + /** + * Encode given array in a buffer. + * @return bytes written + */ + template size_t encode(const std::vector& arg, Buffer::Instance& dst); + + /** + * Encode given nullable array in a buffer. + * @return bytes written + */ + template size_t encode(const NullableArray& arg, Buffer::Instance& dst); + + int16_t apiVersion() const { return api_version_; } + +private: + const int16_t api_version_; +}; + +/** + * For non-primitive types, call `encode` on them, to delegate the serialization to the entity + * itself. + */ +template inline size_t EncodingContext::encode(const T& arg, Buffer::Instance& dst) { + return arg.encode(dst, *this); +} + +/** + * Template overload for int8_t. + * Encode a single byte. + */ +template <> inline size_t EncodingContext::encode(const int8_t& arg, Buffer::Instance& dst) { + dst.add(&arg, sizeof(int8_t)); + return sizeof(int8_t); +} + +/** + * Template overload for int16_t, int32_t, uint32_t, int64_t. + * Encode a N-byte integer, converting to network byte-order. 
+ */ +#define ENCODE_NUMERIC_TYPE(TYPE, CONVERTER) \ + template <> inline size_t EncodingContext::encode(const TYPE& arg, Buffer::Instance& dst) { \ + const TYPE val = CONVERTER(arg); \ + dst.add(&val, sizeof(TYPE)); \ + return sizeof(TYPE); \ + } + +ENCODE_NUMERIC_TYPE(int16_t, htobe16); +ENCODE_NUMERIC_TYPE(int32_t, htobe32); +ENCODE_NUMERIC_TYPE(uint32_t, htobe32); +ENCODE_NUMERIC_TYPE(int64_t, htobe64); + +/** + * Template overload for bool. + * Encode boolean as a single byte. + */ +template <> inline size_t EncodingContext::encode(const bool& arg, Buffer::Instance& dst) { + int8_t val = arg; + dst.add(&val, sizeof(int8_t)); + return sizeof(int8_t); +} + +/** + * Template overload for std::string. + * Encode string as INT16 length + N bytes. + */ +template <> inline size_t EncodingContext::encode(const std::string& arg, Buffer::Instance& dst) { + int16_t string_length = arg.length(); + size_t header_length = encode(string_length, dst); + dst.add(arg.c_str(), string_length); + return header_length + string_length; +} + +/** + * Template overload for NullableString. + * Encode nullable string as INT16 length + N bytes (length = -1 for null). + */ +template <> +inline size_t EncodingContext::encode(const NullableString& arg, Buffer::Instance& dst) { + if (arg.has_value()) { + return encode(*arg, dst); + } else { + const int16_t len = -1; + return encode(len, dst); + } +} + +/** + * Template overload for Bytes. + * Encode byte array as INT32 length + N bytes. + */ +template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Instance& dst) { + const int32_t data_length = arg.size(); + const size_t header_length = encode(data_length, dst); + dst.add(arg.data(), arg.size()); + return header_length + data_length; +} + +/** + * Template overload for NullableBytes. + * Encode nullable byte array as INT32 length + N bytes (length = -1 for null value). + */ +template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buffer::Instance& dst) { + if (arg.has_value()) { + return encode(*arg, dst); + } else { + const int32_t len = -1; + return encode(len, dst); + } +} + +/** + * Encode nullable object array to T as INT32 length + N elements. + * Each element of type T then serializes itself on its own. + */ +template +size_t EncodingContext::encode(const std::vector& arg, Buffer::Instance& dst) { + const NullableArray wrapped = {arg}; + return encode(wrapped, dst); +} + +/** + * Encode nullable object array to T as INT32 length + N elements (length = -1 for null value). + * Each element of type T then serializes itself on its own. + */ +template +size_t EncodingContext::encode(const NullableArray& arg, Buffer::Instance& dst) { + if (arg.has_value()) { + const int32_t len = arg->size(); + const size_t header_length = encode(len, dst); + size_t written{0}; + for (const T& el : *arg) { + // For each of array elements, resolve the correct method again. + // Elements could be primitives or complex types, so calling encode() on object won't work. 
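Putting the overloads together (sketch; ExamplePayload is a made-up structure standing in for the generated request structures): a composite type serializes itself by delegating each field back to the context.

#include "extensions/filters/network/kafka/serialization.h"

#include "common/buffer/buffer_impl.h"

// Sketch only: how a structure cooperates with EncodingContext.
struct ExamplePayload {
  int32_t id_;
  NullableString name_;
  std::vector<int32_t> values_;

  size_t encode(Buffer::Instance& dst, EncodingContext& context) const {
    size_t written{0};
    written += context.encode(id_, dst);     // INT32, big-endian.
    written += context.encode(name_, dst);   // INT16 length (or -1) + bytes.
    written += context.encode(values_, dst); // INT32 count + big-endian elements.
    return written;
  }
};

void encodingContextExample() {
  Buffer::OwnedImpl buffer;
  EncodingContext context{0}; // api_version only matters for version-dependent structures.
  const ExamplePayload payload{42, absl::make_optional<std::string>("name"), {1, 2, 3}};
  context.encode(payload, buffer); // Dispatches to ExamplePayload::encode via the generic overload.
}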
+ written += encode(el, dst); + } + return header_length + written; + } else { + const int32_t len = -1; + return encode(len, dst); + } +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py new file mode 100755 index 0000000000000..100bf7593c71e --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py @@ -0,0 +1,78 @@ +#!/usr/bin/python + + +def main(): + """ + Serialization composite generator script + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Generates main&test source code files for composite deserializers. + The files are generated, as they are extremely repetitive (composite deserializer for 0..9 + sub-deserializers). + + Usage: + serialization_composite_generator.py COMMAND LOCATION_OF_OUTPUT_FILE + where: + COMMAND : 'generate-source', to generate source files, + 'generate-test', to generate test files. + LOCATION_OF_OUTPUT_FILE : if generate-source: location of 'serialization_composite.h', + if generate-test: location of 'serialization_composite_test.cc'. + + When generating source code, it creates: + - serialization_composite.h - header with declarations of CompositeDeserializerWith???Delegates + classes. + When generating test code, it creates: + - serialization_composite_test.cc - tests for these classes. + + Templates used are: + - to create 'serialization_composite.h': serialization_composite_h.j2, + - to create 'serialization_composite_test.cc': serialization_composite_test_cc.j2. + """ + + import sys + import os + + command = sys.argv[1] + if 'generate-source' == command: + serialization_composite_h_file = os.path.abspath(sys.argv[2]) + elif 'generate-test' == command: + serialization_composite_test_cc_file = os.path.abspath(sys.argv[2]) + else: + raise ValueError('invalid command: ' + command) + + import re + import json + + # Number of fields deserialized by each deserializer class. + field_counts = range(1, 10) + + # Generate main source code. + if 'generate-source' == command: + template = RenderingHelper.get_template('serialization_composite_h.j2') + contents = template.render(counts=field_counts) + with open(serialization_composite_h_file, 'w') as fd: + fd.write(contents) + + # Generate test code. + if 'generate-test' == command: + template = RenderingHelper.get_template('serialization_composite_test_cc.j2') + contents = template.render(counts=field_counts) + with open(serialization_composite_test_cc_file, 'w') as fd: + fd.write(contents) + + +class RenderingHelper: + """ + Helper for jinja templates. + """ + + @staticmethod + def get_template(template): + import jinja2 + import os + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(searchpath=os.path.dirname(os.path.abspath(__file__)))) + return env.get_template(template) + + +if __name__ == "__main__": + main() diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 new file mode 100644 index 0000000000000..5b88fa9538bed --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 @@ -0,0 +1,100 @@ +{# + Creates 'serialization_composite.h'. 
+ + Template for composite serializers (the CompositeDeserializerWith_N_Delegates classes). + Covers the corner case of 0 delegates, and then uses templating to create declarations for 1..N + variants. +#} +#pragma once + +#include +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" +#include "envoy/common/pure.h" + +#include "common/common/byte_order.h" +#include "common/common/fmt.h" + +#include "extensions/filters/network/kafka/kafka_types.h" +#include "extensions/filters/network/kafka/serialization.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * This header contains only composite deserializers. + * The basic design is composite deserializer creating delegates DeserializerType1..N. + * Result of type ResponseType is constructed by getting results of each of delegates. + * These deserializers can throw, if any of the delegate deserializers can. + */ + +/** + * Composite deserializer that uses 0 deserializer(s) (corner case). + * Does not consume any bytes, and is always ready to return the result. + * Creates a result value using the no-arg ResponseType constructor. + * @param ResponseType type of deserialized data. + */ +template +class CompositeDeserializerWith0Delegates : public Deserializer { +public: + CompositeDeserializerWith0Delegates(){}; + size_t feed(absl::string_view&) override { return 0; } + bool ready() const override { return true; } + ResponseType get() const override { return {}; } +}; + +{% for field_count in counts %} +/** + * Composite deserializer that uses {{ field_count }} deserializer(s). + * Passes data to each of the underlying deserializers (deserializers that are already ready do not + * consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready (what means that all + * deserializers before it are ready too). + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... }. + * + * @param ResponseType type of deserialized data{% for field in range(1, field_count + 1) %}. + * @param DeserializerType{{ field }} deserializer {{ field }}. 
+{% endfor %} */ +template < + typename ResponseType{% for field in range(1, field_count + 1) %}, + typename DeserializerType{{ field }}{% endfor %} +> +class CompositeDeserializerWith{{ field_count }}Delegates : public Deserializer { +public: + CompositeDeserializerWith{{ field_count }}Delegates(){}; + + size_t feed(absl::string_view& data) override { + size_t consumed = 0; + {% for field in range(1, field_count + 1) %} + consumed += delegate{{ field }}_.feed(data); + {% endfor %} + return consumed; + } + + bool ready() const override { return delegate{{ field_count }}_.ready(); } + + ResponseType get() const override { + return { + {% for field in range(1, field_count + 1) %}delegate{{ field }}_.get(), + {% endfor %}}; + } + +protected: + {% for field in range(1, field_count + 1) %} + DeserializerType{{ field }} delegate{{ field }}_; + {% endfor %} +}; +{% endfor %} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 new file mode 100644 index 0000000000000..ee6204171eb64 --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 @@ -0,0 +1,90 @@ +{# + Creates 'serialization_composite_test.cc'. + + Template for composite serializer tests (the CompositeDeserializerWith_N_Delegates classes). + Covers the corner case of 0 delegates, and then uses templating to create tests for 1..N cases. +#} + +#include "extensions/filters/network/kafka/external/serialization_composite.h" + +#include "test/extensions/filters/network/kafka/serialization_utilities.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Tests in this class are supposed to check whether serialization operations on composite + * deserializers are correct. + */ + +// Tests for composite deserializer with 0 fields (corner case). + +struct CompositeResultWith0Fields { + size_t encode(Buffer::Instance&, EncodingContext&) const { return 0; } + bool operator==(const CompositeResultWith0Fields&) const { return true; } +}; + +typedef CompositeDeserializerWith0Delegates TestCompositeDeserializer0; + +// Composite with 0 delegates is special case: it's always ready. +TEST(CompositeDeserializerWith0Delegates, EmptyBufferShouldBeReady) { + // given + const TestCompositeDeserializer0 testee{}; + // when, then + ASSERT_EQ(testee.ready(), true); +} + +TEST(CompositeDeserializerWith0Delegates, ShouldDeserialize) { + const CompositeResultWith0Fields expected{}; + serializeThenDeserializeAndCheckEquality(expected); +} + +// Tests for composite deserializer with N+ fields. 
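For illustration (not part of this change): this is how such a composite is typically instantiated. TopicAndPartition and its fields are made up here; the generated request deserializers are built the same way.

#include "extensions/filters/network/kafka/external/serialization_composite.h"

// Sketch only: a two-field structure deserialized by a 2-delegate composite.
struct TopicAndPartition {
  std::string topic_;
  int32_t partition_;
};

typedef CompositeDeserializerWith2Delegates<TopicAndPartition, StringDeserializer,
                                            Int32Deserializer>
    TopicAndPartitionDeserializer;

// Feeding the bytes of STRING "ab" followed by INT32 1 yields TopicAndPartition{"ab", 1}.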
+ +{% for field_count in counts %} +struct CompositeResultWith{{ field_count }}Fields { + {% for field in range(1, field_count + 1) %} + const std::string field{{ field }}_; + {% endfor %} + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + {% for field in range(1, field_count + 1) %} + written += encoder.encode(field{{ field }}_, dst); + {% endfor %} + return written; + } + + bool operator==(const CompositeResultWith{{ field_count }}Fields& rhs) const { + return true + {% for field in range(1, field_count + 1) %} && field{{ field }}_ == rhs.field{{ field }}_ + {% endfor %}; + } +}; + +typedef CompositeDeserializerWith{{ field_count }}Delegates< + CompositeResultWith{{ field_count }}Fields + {% for field in range(1, field_count + 1) %}, StringDeserializer{% endfor %} +> TestCompositeDeserializer{{ field_count }}; + +TEST(CompositeDeserializerWith{{ field_count }}Delegates, EmptyBufferShouldNotBeReady) { + // given + const TestCompositeDeserializer{{ field_count }} testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +TEST(CompositeDeserializerWith{{ field_count }}Delegates, ShouldDeserialize) { + const CompositeResultWith{{ field_count }}Fields expected{ + {% for field in range(1, field_count + 1) %}"s{{ field }}", {% endfor %} + }; + serializeThenDeserializeAndCheckEquality(expected); +} +{% endfor %} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD new file mode 100644 index 0000000000000..9bb39da434071 --- /dev/null +++ b/test/extensions/filters/network/kafka/BUILD @@ -0,0 +1,131 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_cc_test_library", + "envoy_package", +) +load( + "//test/extensions:extensions_build_system.bzl", + "envoy_extension_cc_test", +) + +envoy_package() + +envoy_cc_test_library( + name = "serialization_utilities_lib", + srcs = ["serialization_utilities.cc"], + hdrs = ["serialization_utilities.h"], + deps = [ + "//source/common/buffer:buffer_lib", + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//source/extensions/filters/network/kafka:serialization_lib", + ], +) + +envoy_extension_cc_test( + name = "serialization_test", + srcs = ["serialization_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + ":serialization_utilities_lib", + "//source/extensions/filters/network/kafka:serialization_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "serialization_composite_test", + srcs = ["external/serialization_composite_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + ":serialization_utilities_lib", + "//source/extensions/filters/network/kafka:serialization_lib", + "//test/mocks/server:server_mocks", + ], +) + +genrule( + name = "serialization_composite_test_generator", + srcs = [], + outs = ["external/serialization_composite_test.cc"], + cmd = """ + ./$(location //source/extensions/filters/network/kafka:serialization_composite_generator) \ + generate-test $(location external/serialization_composite_test.cc) + """, + tools = [ + "//source/extensions/filters/network/kafka:serialization_composite_generator", + ], +) + +envoy_extension_cc_test( + name = "kafka_request_parser_test", + srcs = ["kafka_request_parser_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + ":serialization_utilities_lib", 
+ "//source/extensions/filters/network/kafka:kafka_request_parser_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "request_codec_unit_test", + srcs = ["request_codec_unit_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "request_codec_integration_test", + srcs = ["request_codec_integration_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + ":serialization_utilities_lib", + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "request_codec_request_test", + srcs = ["external/request_codec_request_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + ":serialization_utilities_lib", + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "requests_test", + srcs = ["external/requests_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) + +genrule( + name = "requests_test_generator", + srcs = [ + "@kafka_source//:request_protocol_files", + ], + outs = [ + "external/requests_test.cc", + "external/request_codec_request_test.cc", + ], + cmd = """ + ./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test \ + $(location external/requests_test.cc) $(location external/request_codec_request_test.cc) \ + $(SRCS) + """, + tools = [ + "//source/extensions/filters/network/kafka:kafka_code_generator", + ], +) diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc new file mode 100644 index 0000000000000..e7163e0fab865 --- /dev/null +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -0,0 +1,263 @@ +#include "extensions/filters/network/kafka/kafka_request_parser.h" + +#include "test/extensions/filters/network/kafka/serialization_utilities.h" +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" + +using testing::_; +using testing::Return; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { +namespace KafkaRequestParserTest { + +const int32_t FAILED_DESERIALIZER_STEP = 13; + +class KafkaRequestParserTest : public testing::Test { +public: + const char* getBytes() { + uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + buffer_.getRawSlices(slices.begin(), num_slices); + return reinterpret_cast((slices[0]).mem_); + } + + template size_t putIntoBuffer(const T arg) { + EncodingContext encoder_{-1}; // Context's api_version is not used when serializing primitives. 
+ return encoder_.encode(arg, buffer_); + } + + absl::string_view putGarbageIntoBuffer(size_t size = 10000) { + putIntoBuffer(Bytes(size)); + return {getBytes(), size}; + } + +protected: + Buffer::OwnedImpl buffer_; +}; + +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver(){}; + MOCK_CONST_METHOD3(createParser, + RequestParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); +}; + +TEST_F(KafkaRequestParserTest, RequestStartParserTestShouldReturnRequestHeaderParser) { + // given + MockRequestParserResolver resolver{}; + RequestStartParser testee{resolver}; + + int32_t request_len = 1234; + putIntoBuffer(request_len); + + const absl::string_view orig_data = {getBytes(), 1024}; + absl::string_view data = orig_data; + + // when + const RequestParseResponse result = testee.parse(data); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); + assertStringViewIncrement(data, orig_data, sizeof(int32_t)); +} + +class MockParser : public RequestParser { +public: + RequestParseResponse parse(absl::string_view&) override { + throw new EnvoyException("should not be invoked"); + } +}; + +TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldExtractHeaderAndResolveNextParser) { + // given + const MockRequestParserResolver parser_resolver; + const RequestParserSharedPtr parser{new MockParser{}}; + EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser)); + + const int32_t request_len = 1000; + RequestContextSharedPtr context{new RequestContext()}; + context->remaining_request_size_ = request_len; + RequestHeaderParser testee{parser_resolver, context}; + + const int16_t api_key{1}; + const int16_t api_version{2}; + const int32_t correlation_id{10}; + const NullableString client_id{"aaa"}; + size_t header_len = 0; + header_len += putIntoBuffer(api_key); + header_len += putIntoBuffer(api_version); + header_len += putIntoBuffer(correlation_id); + header_len += putIntoBuffer(client_id); + + const absl::string_view orig_data = putGarbageIntoBuffer(); + absl::string_view data = orig_data; + + // when + const RequestParseResponse result = testee.parse(data); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, parser); + ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); + + const RequestHeader expected_header{api_key, api_version, correlation_id, client_id}; + ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len - header_len); + + assertStringViewIncrement(data, orig_data, header_len); +} + +TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldHandleExceptionsDuringFeeding) { + // given + + // This deserializer throws during feeding. + class ThrowingRequestHeaderDeserializer : public RequestHeaderDeserializer { + public: + size_t feed(absl::string_view& data) override { + // Move some pointers to simulate data consumption. 
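For reference, a sketch of the byte layout this parser chain walks through; the header fields and their order follow the values used in the test above.

// Sketch only: on-wire layout of a Kafka request as consumed by the parser chain under test.
//
//   INT32            size             <- RequestStartParser
//   INT16            api_key          <- RequestHeaderParser
//   INT16            api_version
//   INT32            correlation_id
//   NULLABLE_STRING  client_id
//   ...              request payload  <- parser returned by RequestParserResolver::createParser
//
// e.g. the header {api_key = 1, api_version = 2, correlation_id = 10, client_id = "aaa"} becomes:
//   00 01 | 00 02 | 00 00 00 0a | 00 03 61 61 61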
+ data = {data.data() + FAILED_DESERIALIZER_STEP, data.size() - FAILED_DESERIALIZER_STEP}; + throw EnvoyException("feed"); + }; + + bool ready() const override { throw std::runtime_error("should not be invoked at all"); }; + + RequestHeader get() const override { + throw std::runtime_error("should not be invoked at all"); + }; + }; + + const MockRequestParserResolver parser_resolver; + + const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. + RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; + RequestHeaderParser testee{parser_resolver, request_context, + std::make_unique()}; + + const absl::string_view orig_data = putGarbageIntoBuffer(); + absl::string_view data = orig_data; + + // when + const RequestParseResponse result = testee.parse(data); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, + request_size - FAILED_DESERIALIZER_STEP); + + assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); +} + +TEST_F(KafkaRequestParserTest, RequestDataParserShouldHandleDeserializerExceptionsDuringFeeding) { + // given + + // This deserializer throws during feeding. + class ThrowingDeserializer : public Deserializer { + public: + size_t feed(absl::string_view&) override { + // Move some pointers to simulate data consumption. + throw EnvoyException("feed"); + }; + + bool ready() const override { throw std::runtime_error("should not be invoked at all"); }; + + int32_t get() const override { throw std::runtime_error("should not be invoked at all"); }; + }; + + RequestContextSharedPtr request_context{new RequestContext{1024, {}}}; + RequestDataParser testee{request_context}; + + absl::string_view data = putGarbageIntoBuffer(); + + // when + bool caught = false; + try { + testee.parse(data); + } catch (EnvoyException& e) { + caught = true; + } + + // then + ASSERT_EQ(caught, true); +} + +// This deserializer consumes FAILED_DESERIALIZER_STEP bytes and returns 0 +class SomeBytesDeserializer : public Deserializer { +public: + size_t feed(absl::string_view& data) override { + data = {data.data() + FAILED_DESERIALIZER_STEP, data.size() - FAILED_DESERIALIZER_STEP}; + return FAILED_DESERIALIZER_STEP; + }; + + bool ready() const override { return true; }; + + int32_t get() const override { return 0; }; +}; + +TEST_F(KafkaRequestParserTest, + RequestDataParserShouldHandleDeserializerReturningReadyButLeavingData) { + // given + const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. 
+ RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; + + RequestDataParser testee{request_context}; + + const absl::string_view orig_data = putGarbageIntoBuffer(); + absl::string_view data = orig_data; + + // when + const RequestParseResponse result = testee.parse(data); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, + request_size - FAILED_DESERIALIZER_STEP); + + assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); +} + +TEST_F(KafkaRequestParserTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { + // given + const int32_t request_len = 1000; + RequestContextSharedPtr context{new RequestContext()}; + context->remaining_request_size_ = request_len; + SentinelParser testee{context}; + + const absl::string_view orig_data = putGarbageIntoBuffer(request_len * 2); + absl::string_view data = orig_data; + + // when + const RequestParseResponse result = testee.parse(data); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.failure_data_), nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); + + assertStringViewIncrement(data, orig_data, request_len); +} + +} // namespace KafkaRequestParserTest +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc new file mode 100644 index 0000000000000..6d087f310466a --- /dev/null +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -0,0 +1,72 @@ +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/extensions/filters/network/kafka/serialization_utilities.h" +#include "test/mocks/server/mocks.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { +namespace RequestCodecIntegrationTest { + +class RequestCodecIntegrationTest : public testing::Test { +protected: + template void putInBuffer(T arg); + + Buffer::OwnedImpl buffer_; +}; + +// Other request types are tested in (generated) 'request_codec_request_test.cc'. +TEST_F(RequestCodecIntegrationTest, shouldProduceAbortedMessageOnUnknownData) { + // given + // As real api keys have values below 100, the messages generated in this loop should not be + // recognized by the codec. 
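For illustration (not part of this change): a consumer of the codec can observe such aborted messages through onFailedParse; CountingCallback is a hypothetical implementation.

#include "extensions/filters/network/kafka/request_codec.h"

// Sketch only: the parse failure still carries the request header, so the api_key and
// correlation_id of the unrecognized request remain available to the callback.
class CountingCallback : public RequestCallback {
public:
  void onMessage(AbstractRequestSharedPtr) override { parsed_++; }
  void onFailedParse(RequestParseFailureSharedPtr failure) override {
    unknown_headers_.push_back(failure->request_header_);
  }

  size_t parsed_{0};
  std::vector<RequestHeader> unknown_headers_;
};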
+ const int16_t base_api_key = 100; + std::vector sent_headers; + for (int16_t i = 0; i < 1000; ++i) { + const int16_t api_key = static_cast(base_api_key + i); + const RequestHeader header = {api_key, 0, 0, "client-id"}; + const std::vector data = std::vector(1024); + putInBuffer(Request>{header, data}); + sent_headers.push_back(header); + } + + const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance(); + const RequestParserResolver& request_parser_resolver = + RequestParserResolver::getDefaultInstance(); + const CapturingRequestCallbackSharedPtr request_callback = + std::make_shared(); + + RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}}; + + // when + testee.onData(buffer_); + + // then + ASSERT_EQ(request_callback->getCaptured().size(), 0); + + const std::vector& parse_failures = + request_callback->getParseFailures(); + ASSERT_EQ(parse_failures.size(), sent_headers.size()); + + for (size_t i = 0; i < parse_failures.size(); ++i) { + const std::shared_ptr failure_data = + std::dynamic_pointer_cast(parse_failures[i]); + ASSERT_NE(failure_data, nullptr); + ASSERT_EQ(failure_data->request_header_, sent_headers[i]); + } +} + +// Helper function. +template void RequestCodecIntegrationTest::putInBuffer(T arg) { + RequestEncoder encoder{buffer_}; + encoder.encode(arg); +} + +} // namespace RequestCodecIntegrationTest +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc new file mode 100644 index 0000000000000..3685f019759e3 --- /dev/null +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -0,0 +1,211 @@ +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::AnyNumber; +using testing::Eq; +using testing::Invoke; +using testing::ResultOf; +using testing::Return; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { +namespace RequestCodecUnitTest { + +class MockParserFactory : public InitialParserFactory { +public: + MOCK_CONST_METHOD1(create, RequestParserSharedPtr(const RequestParserResolver&)); +}; + +class MockParser : public RequestParser { +public: + MOCK_METHOD1(parse, RequestParseResponse(absl::string_view&)); +}; + +typedef std::shared_ptr MockParserSharedPtr; + +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver() : RequestParserResolver({}){}; + MOCK_CONST_METHOD3(createParser, + RequestParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); +}; + +class MockRequestCallback : public RequestCallback { +public: + MOCK_METHOD1(onMessage, void(AbstractRequestSharedPtr)); + MOCK_METHOD1(onFailedParse, void(RequestParseFailureSharedPtr)); +}; + +typedef std::shared_ptr MockRequestCallbackSharedPtr; + +class RequestCodecUnitTest : public testing::Test { +protected: + template void putInBuffer(T arg); + + Buffer::OwnedImpl buffer_; + + MockParserFactory initial_parser_factory_{}; + MockRequestParserResolver parser_resolver_{}; + MockRequestCallbackSharedPtr request_callback_{std::make_shared()}; +}; + +RequestParseResponse consumeOneByte(absl::string_view& data) { + data = {data.data() + 1, data.size() - 1}; + return RequestParseResponse::stillWaiting(); +} + 
+TEST_F(RequestCodecUnitTest, shouldDoNothingIfParserNeverReturnsMessage) { + // given + putInBuffer(Request{{}, 0}); + + MockParserSharedPtr parser = std::make_shared(); + EXPECT_CALL(*parser, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)).WillOnce(Return(parser)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then + // There were no interactions with `request_callback`. +} + +TEST_F(RequestCodecUnitTest, shouldUseNewParserAsResponse) { + // given + putInBuffer(Request{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + MockParserSharedPtr parser2 = std::make_shared(); + MockParserSharedPtr parser3 = std::make_shared(); + EXPECT_CALL(*parser1, parse(_)).WillOnce(Return(RequestParseResponse::nextParser(parser2))); + EXPECT_CALL(*parser2, parse(_)).WillOnce(Return(RequestParseResponse::nextParser(parser3))); + EXPECT_CALL(*parser3, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)).WillOnce(Return(parser1)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then + // There were no interactions with `request_callback`. +} + +TEST_F(RequestCodecUnitTest, shouldPassParsedMessageToCallbackAndReinitialize) { + // given + putInBuffer(Request{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + RequestParseFailureSharedPtr failure_data = + std::make_shared(RequestHeader()); + EXPECT_CALL(*parser1, parse(_)) + .WillOnce(Return(RequestParseResponse::parseFailure(failure_data))); + + MockParserSharedPtr parser2 = std::make_shared(); + EXPECT_CALL(*parser2, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)) + .WillOnce(Return(parser1)) + .WillOnce(Return(parser2)); + + EXPECT_CALL(*request_callback_, onFailedParse(failure_data)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then + // There was only one message sent to `request_callback`. +} + +TEST_F(RequestCodecUnitTest, shouldPassParseFailureDataToCallbackAndReinitialize) { + // given + putInBuffer(Request{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + RequestParseFailureSharedPtr failure_data = + std::make_shared(RequestHeader()); + EXPECT_CALL(*parser1, parse(_)) + .WillOnce(Return(RequestParseResponse::parseFailure(failure_data))); + + MockParserSharedPtr parser2 = std::make_shared(); + EXPECT_CALL(*parser2, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)) + .WillOnce(Return(parser1)) + .WillOnce(Return(parser2)); + + EXPECT_CALL(*request_callback_, onFailedParse(failure_data)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then + // `request_callback` had `onFailedParse` invoked once with matching argument. 
+} + +TEST_F(RequestCodecUnitTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { + // given + putInBuffer(Request{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + MockParserSharedPtr parser2 = std::make_shared(); + MockParserSharedPtr parser3 = std::make_shared(); + + // parser1 consumes buffer_.length() bytes (== everything) and returns parser2 + auto consume_and_return = [this, &parser2](absl::string_view& data) -> RequestParseResponse { + data = {data.data() + buffer_.length(), data.size() - buffer_.length()}; + return RequestParseResponse::nextParser(parser2); + }; + EXPECT_CALL(*parser1, parse(_)).WillOnce(Invoke(consume_and_return)); + + // parser2 just returns parse result + RequestParseFailureSharedPtr failure_data = + std::make_shared(RequestHeader{}); + EXPECT_CALL(*parser2, parse(_)) + .WillOnce(Return(RequestParseResponse::parseFailure(failure_data))); + + // parser3 just consumes everything + EXPECT_CALL(*parser3, parse(ResultOf([](absl::string_view arg) { return arg.size(); }, Eq(0)))) + .WillOnce(Return(RequestParseResponse::stillWaiting())); + + EXPECT_CALL(initial_parser_factory_, create(_)) + .WillOnce(Return(parser1)) + .WillOnce(Return(parser3)); + + EXPECT_CALL(*request_callback_, onFailedParse(failure_data)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then + // `request_callback` was invoked only once. + // After that, `parser3` was created and passed remaining data (that should have been empty). +} + +// Helper function. +template void RequestCodecUnitTest::putInBuffer(T arg) { + RequestEncoder encoder{buffer_}; + encoder.encode(arg); +} + +} // namespace RequestCodecUnitTest +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc new file mode 100644 index 0000000000000..d52cdff25fe13 --- /dev/null +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -0,0 +1,229 @@ +#include "test/extensions/filters/network/kafka/serialization_utilities.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { +namespace SerializationTest { + +/** + * Tests in this file are supposed to check whether serialization operations + * on Kafka-primitive types (ints, strings, arrays) are behaving correctly. + */ + +// Freshly created deserializers should not be ready. 
+#define TEST_EmptyDeserializerShouldNotBeReady(DeserializerClass) \
+  TEST(DeserializerClass, EmptyBufferShouldNotBeReady) { \
+    const DeserializerClass testee{}; \
+    ASSERT_EQ(testee.ready(), false); \
+  }
+
+TEST_EmptyDeserializerShouldNotBeReady(Int8Deserializer);
+TEST_EmptyDeserializerShouldNotBeReady(Int16Deserializer);
+TEST_EmptyDeserializerShouldNotBeReady(Int32Deserializer);
+TEST_EmptyDeserializerShouldNotBeReady(UInt32Deserializer);
+TEST_EmptyDeserializerShouldNotBeReady(Int64Deserializer);
+TEST_EmptyDeserializerShouldNotBeReady(BooleanDeserializer);
+
+TEST_EmptyDeserializerShouldNotBeReady(StringDeserializer);
+TEST_EmptyDeserializerShouldNotBeReady(NullableStringDeserializer);
+TEST_EmptyDeserializerShouldNotBeReady(BytesDeserializer);
+TEST_EmptyDeserializerShouldNotBeReady(NullableBytesDeserializer);
+
+TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) {
+  // given
+  const ArrayDeserializer<StringDeserializer> testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+TEST(NullableArrayDeserializer, EmptyBufferShouldNotBeReady) {
+  // given
+  const NullableArrayDeserializer<StringDeserializer> testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+// Extracted test for numeric buffers.
+#define TEST_DeserializerShouldDeserialize(BufferClass, DataClass, Value) \
+  TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \
+    /* given */ \
+    const DataClass value = Value; \
+    serializeThenDeserializeAndCheckEquality<BufferClass>(value); \
+  }
+
+TEST_DeserializerShouldDeserialize(Int8Deserializer, int8_t, 42);
+TEST_DeserializerShouldDeserialize(Int16Deserializer, int16_t, 42);
+TEST_DeserializerShouldDeserialize(Int32Deserializer, int32_t, 42);
+TEST_DeserializerShouldDeserialize(UInt32Deserializer, uint32_t, 42);
+TEST_DeserializerShouldDeserialize(Int64Deserializer, int64_t, 42);
+TEST_DeserializerShouldDeserialize(BooleanDeserializer, bool, true);
+
+EncodingContext encoder{-1}; // Provided api_version does not matter for primitive types.
+
+TEST(StringDeserializer, ShouldDeserialize) {
+  const std::string value = "sometext";
+  serializeThenDeserializeAndCheckEquality<StringDeserializer>(value);
+}
+
+TEST(StringDeserializer, ShouldDeserializeEmptyString) {
+  const std::string value = "";
+  serializeThenDeserializeAndCheckEquality<StringDeserializer>(value);
+}
+
+TEST(StringDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  StringDeserializer testee;
+  Buffer::OwnedImpl buffer;
+
+  int16_t len = -1; // STRING accepts length >= 0.
+  encoder.encode(len, buffer);
+
+  absl::string_view data = {getRawData(buffer), 1024};
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data), EnvoyException);
+}
+
+TEST(NullableStringDeserializer, ShouldDeserializeString) {
+  // given
+  const NullableString value{"sometext"};
+  serializeThenDeserializeAndCheckEquality<NullableStringDeserializer>(value);
+}
+
+TEST(NullableStringDeserializer, ShouldDeserializeEmptyString) {
+  // given
+  const NullableString value{""};
+  serializeThenDeserializeAndCheckEquality<NullableStringDeserializer>(value);
+}
+
+TEST(NullableStringDeserializer, ShouldDeserializeAbsentString) {
+  // given
+  const NullableString value = absl::nullopt;
+  serializeThenDeserializeAndCheckEquality<NullableStringDeserializer>(value);
+}
+
+TEST(NullableStringDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableStringDeserializer testee;
+  Buffer::OwnedImpl buffer;
+
+  int16_t len = -2; // -1 is OK for NULLABLE_STRING.
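+  // (In the Kafka protocol, a NULLABLE_STRING length of -1 denotes a null string, so anything
+  // below -1 must be rejected.)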
+  encoder.encode(len, buffer);
+
+  absl::string_view data = {getRawData(buffer), 1024};
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data), EnvoyException);
+}
+
+TEST(BytesDeserializer, ShouldDeserialize) {
+  const Bytes value{'a', 'b', 'c', 'd'};
+  serializeThenDeserializeAndCheckEquality<BytesDeserializer>(value);
+}
+
+TEST(BytesDeserializer, ShouldDeserializeEmptyBytes) {
+  const Bytes value{};
+  serializeThenDeserializeAndCheckEquality<BytesDeserializer>(value);
+}
+
+TEST(BytesDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  BytesDeserializer testee;
+  Buffer::OwnedImpl buffer;
+
+  const int32_t bytes_length = -1; // BYTES accepts length >= 0.
+  encoder.encode(bytes_length, buffer);
+
+  absl::string_view data = {getRawData(buffer), 1024};
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data), EnvoyException);
+}
+
+TEST(NullableBytesDeserializer, ShouldDeserialize) {
+  const NullableBytes value{{'a', 'b', 'c', 'd'}};
+  serializeThenDeserializeAndCheckEquality<NullableBytesDeserializer>(value);
+}
+
+TEST(NullableBytesDeserializer, ShouldDeserializeEmptyBytes) {
+  const NullableBytes value{{}};
+  serializeThenDeserializeAndCheckEquality<NullableBytesDeserializer>(value);
+}
+
+TEST(NullableBytesDeserializer, ShouldDeserializeNullBytes) {
+  const NullableBytes value = absl::nullopt;
+  serializeThenDeserializeAndCheckEquality<NullableBytesDeserializer>(value);
+}
+
+TEST(NullableBytesDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableBytesDeserializer testee;
+  Buffer::OwnedImpl buffer;
+
+  const int32_t bytes_length = -2; // -1 is OK for NULLABLE_BYTES.
+  encoder.encode(bytes_length, buffer);
+
+  absl::string_view data = {getRawData(buffer), 1024};
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data), EnvoyException);
+}
+
+TEST(ArrayDeserializer, ShouldConsumeCorrectAmountOfData) {
+  const std::vector<std::string> value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}};
+  serializeThenDeserializeAndCheckEquality<ArrayDeserializer<StringDeserializer>>(value);
+}
+
+TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  ArrayDeserializer<StringDeserializer> testee;
+  Buffer::OwnedImpl buffer;
+
+  const int32_t len = -1; // ARRAY accepts length >= 0.
+  encoder.encode(len, buffer);
+
+  absl::string_view data = {getRawData(buffer), 1024};
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data), EnvoyException);
+}
+
+TEST(NullableArrayDeserializer, ShouldConsumeCorrectAmountOfData) {
+  const NullableArray<std::string> value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}};
+  serializeThenDeserializeAndCheckEquality<
+      NullableArrayDeserializer<StringDeserializer>>(value);
+}
+
+TEST(NullableArrayDeserializer, ShouldConsumeNullArray) {
+  const NullableArray<std::string> value = absl::nullopt;
+  serializeThenDeserializeAndCheckEquality<
+      NullableArrayDeserializer<StringDeserializer>>(value);
+}
+
+TEST(NullableArrayDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableArrayDeserializer<StringDeserializer> testee;
+  Buffer::OwnedImpl buffer;
+
+  const int32_t len = -2; // -1 is OK for NULLABLE_ARRAY.
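+  // (As with NULLABLE_STRING and NULLABLE_BYTES, -1 encodes a null array, so anything below -1
+  // must be rejected.)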
+  encoder.encode(len, buffer);
+
+  absl::string_view data = {getRawData(buffer), 1024};
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data), EnvoyException);
+}
+
+} // namespace SerializationTest
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/test/extensions/filters/network/kafka/serialization_utilities.cc b/test/extensions/filters/network/kafka/serialization_utilities.cc
new file mode 100644
index 0000000000000..95129da2e4003
--- /dev/null
+++ b/test/extensions/filters/network/kafka/serialization_utilities.cc
@@ -0,0 +1,42 @@
+#include "test/extensions/filters/network/kafka/serialization_utilities.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+void assertStringViewIncrement(absl::string_view incremented, absl::string_view original,
+                               size_t difference) {
+
+  ASSERT_EQ(incremented.data(), original.data() + difference);
+  ASSERT_EQ(incremented.size(), original.size() - difference);
+}
+
+const char* getRawData(const Buffer::OwnedImpl& buffer) {
+  uint64_t num_slices = buffer.getRawSlices(nullptr, 0);
+  STACK_ARRAY(slices, Buffer::RawSlice, num_slices);
+  buffer.getRawSlices(slices.begin(), num_slices);
+  return reinterpret_cast<const char*>((slices[0]).mem_);
+}
+
+void CapturingRequestCallback::onMessage(AbstractRequestSharedPtr message) {
+  captured_.push_back(message);
+}
+
+void CapturingRequestCallback::onFailedParse(RequestParseFailureSharedPtr failure_data) {
+  parse_failures_.push_back(failure_data);
+}
+
+const std::vector<AbstractRequestSharedPtr>& CapturingRequestCallback::getCaptured() const {
+  return captured_;
+}
+
+const std::vector<RequestParseFailureSharedPtr>&
+CapturingRequestCallback::getParseFailures() const {
+  return parse_failures_;
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h
new file mode 100644
index 0000000000000..d15071ffa23b6
--- /dev/null
+++ b/test/extensions/filters/network/kafka/serialization_utilities.h
@@ -0,0 +1,145 @@
+#pragma once
+
+#include "common/buffer/buffer_impl.h"
+#include "common/common/stack_array.h"
+
+#include "extensions/filters/network/kafka/request_codec.h"
+#include "extensions/filters/network/kafka/serialization.h"
+
+#include "absl/strings/string_view.h"
+#include "gtest/gtest.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+/**
+ * Verifies that the 'incremented' string view is the 'original' string view advanced by
+ * 'difference' bytes.
+ */
+void assertStringViewIncrement(absl::string_view incremented, absl::string_view original,
+                               size_t difference);
+
+// Helper function converting buffer to raw bytes.
+const char* getRawData(const Buffer::OwnedImpl& buffer);
+
+// Exactly what it says on the tin:
+// 1. serialize expected using Encoder,
+// 2. deserialize byte array using testee deserializer,
+// 3. verify that testee is ready, and its result is equal to expected,
+// 4. verify that data pointer moved correct amount,
+// 5. feed testee more data,
+// 6. verify that nothing more was consumed (because the testee has been ready since step 3).
+template <typename BT, typename AT>
+void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) {
+  // given
+  BT testee{};
+
+  Buffer::OwnedImpl buffer;
+  EncodingContext encoder{-1};
+  const size_t written = encoder.encode(expected, buffer);
+  // Insert garbage after serialized payload.
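+  // The trailing garbage makes it possible to verify that the deserializer consumes exactly
+  // 'written' bytes and leaves everything after that untouched.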
+  const size_t garbage_size = encoder.encode(Bytes(10000), buffer);
+
+  // Tell the deserializer that there is more data; it should never consume more than 'written'.
+  const absl::string_view orig_data = {getRawData(buffer), written + garbage_size};
+  absl::string_view data = orig_data;
+
+  // when
+  const size_t consumed = testee.feed(data);
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), expected);
+  assertStringViewIncrement(data, orig_data, consumed);
+
+  // when - 2
+  const size_t consumed2 = testee.feed(data);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  assertStringViewIncrement(data, orig_data, consumed);
+}
+
+// Does the same thing as the test above, but instead of providing the whole data at once, it
+// provides it in N one-byte chunks.
+// This verifies that the deserializer keeps its state properly (no overwrites etc.).
+template <typename BT, typename AT>
+void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) {
+  // given
+  BT testee{};
+
+  Buffer::OwnedImpl buffer;
+  EncodingContext encoder{-1};
+  const size_t written = encoder.encode(expected, buffer);
+  // Insert garbage after serialized payload.
+  const size_t garbage_size = encoder.encode(Bytes(10000), buffer);
+
+  const absl::string_view orig_data = {getRawData(buffer), written + garbage_size};
+
+  // when
+  absl::string_view data = orig_data;
+  size_t consumed = 0;
+  for (size_t i = 0; i < written; ++i) {
+    data = {data.data(), 1}; // Consume data byte-by-byte.
+    size_t step = testee.feed(data);
+    consumed += step;
+    ASSERT_EQ(step, 1);
+    ASSERT_EQ(data.size(), 0);
+  }
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), expected);
+
+  ASSERT_EQ(data.data(), orig_data.data() + consumed);
+
+  // when - 2
+  absl::string_view more_data = {data.data(), garbage_size};
+  const size_t consumed2 = testee.feed(more_data);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  ASSERT_EQ(more_data.data(), data.data());
+  ASSERT_EQ(more_data.size(), garbage_size);
+}
+
+// Wrapper to run both tests.
+template <typename BT, typename AT> void serializeThenDeserializeAndCheckEquality(AT expected) {
+  serializeThenDeserializeAndCheckEqualityInOneGo<BT>(expected);
+  serializeThenDeserializeAndCheckEqualityWithChunks<BT>(expected);
+}
+
+/**
+ * Request callback that captures the messages.
+ */
+class CapturingRequestCallback : public RequestCallback {
+public:
+  /**
+   * Stores the message.
+   */
+  virtual void onMessage(AbstractRequestSharedPtr request) override;
+
+  /**
+   * Returns the stored messages.
+   */
+  const std::vector<AbstractRequestSharedPtr>& getCaptured() const;
+
+  virtual void onFailedParse(RequestParseFailureSharedPtr failure_data) override;
+
+  const std::vector<RequestParseFailureSharedPtr>& getParseFailures() const;
+
+private:
+  std::vector<AbstractRequestSharedPtr> captured_;
+  std::vector<RequestParseFailureSharedPtr> parse_failures_;
+};
+
+typedef std::shared_ptr<CapturingRequestCallback> CapturingRequestCallbackSharedPtr;
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/tools/spelling_dictionary.txt b/tools/spelling_dictionary.txt
index 1240f39f89a5f..7a399c13abe4f 100644
--- a/tools/spelling_dictionary.txt
+++ b/tools/spelling_dictionary.txt
@@ -384,7 +384,9 @@ dereferencing
 deregistered
 deserialization
 deserialize
+deserialized
 deserializer
+deserializers
 dest
 destructor
 destructors
@@ -578,6 +580,7 @@ parameterizing
 params
 paren
 parentid
+parsers
 pcall
 pcap
 pclose
@@ -701,6 +704,7 @@ templating
 templatize
 templatized
 templatizing
+testee
 th
 thru
 tm