From 4ec6b337544df12c3832f156c22fc392cd503084 Mon Sep 17 00:00:00 2001
From: Adam Kotwasinski
Date: Tue, 25 Sep 2018 15:52:52 +0100
Subject: [PATCH 01/29] WIP: Kafka codec

Signed-off-by: Adam Kotwasinski
---
 source/common/common/logger.h                 |    1 +
 source/extensions/extensions_build_config.bzl |    1 +
 source/extensions/filters/network/kafka/BUILD |   82 +
 .../extensions/filters/network/kafka/codec.h  |   34 +
 .../filters/network/kafka/kafka_protocol.h    |  123 ++
 .../filters/network/kafka/kafka_request.cc    |  135 ++
 .../filters/network/kafka/kafka_request.h     | 1316 +++++++++++++++++
 .../filters/network/kafka/kafka_types.h       |   31 +
 .../filters/network/kafka/message.h           |   30 +
 .../extensions/filters/network/kafka/parser.h |   48 +
 .../filters/network/kafka/request_codec.cc    |   62 +
 .../filters/network/kafka/request_codec.h     |   69 +
 .../filters/network/kafka/serialization.h     |  772 ++++++++++
 .../filters/network/well_known_names.h        |    2 +
 test/extensions/filters/network/kafka/BUILD   |   42 +
 .../network/kafka/kafka_request_test.cc       |  173 +++
 .../network/kafka/request_codec_test.cc       |  538 +++++++
 .../network/kafka/serialization_test.cc       |  405 +++++
 18 files changed, 3864 insertions(+)
 create mode 100644 source/extensions/filters/network/kafka/BUILD
 create mode 100644 source/extensions/filters/network/kafka/codec.h
 create mode 100644 source/extensions/filters/network/kafka/kafka_protocol.h
 create mode 100644 source/extensions/filters/network/kafka/kafka_request.cc
 create mode 100644 source/extensions/filters/network/kafka/kafka_request.h
 create mode 100644 source/extensions/filters/network/kafka/kafka_types.h
 create mode 100644 source/extensions/filters/network/kafka/message.h
 create mode 100644 source/extensions/filters/network/kafka/parser.h
 create mode 100644 source/extensions/filters/network/kafka/request_codec.cc
 create mode 100644 source/extensions/filters/network/kafka/request_codec.h
 create mode 100644 source/extensions/filters/network/kafka/serialization.h
 create mode 100644 test/extensions/filters/network/kafka/BUILD
 create mode 100644 test/extensions/filters/network/kafka/kafka_request_test.cc
 create mode 100644 test/extensions/filters/network/kafka/request_codec_test.cc
 create mode 100644 test/extensions/filters/network/kafka/serialization_test.cc

diff --git a/source/common/common/logger.h b/source/common/common/logger.h
index 341caa20c8c96..ab27722a1a2e3 100644
--- a/source/common/common/logger.h
+++ b/source/common/common/logger.h
@@ -36,6 +36,7 @@ namespace Logger {
   FUNCTION(http)    \
   FUNCTION(http2)   \
   FUNCTION(hystrix) \
+  FUNCTION(kafka)   \
   FUNCTION(lua)     \
   FUNCTION(main)    \
   FUNCTION(misc)    \
diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl
index 5fb01c97c821c..c168a3cd0c7f1 100644
--- a/source/extensions/extensions_build_config.bzl
+++ b/source/extensions/extensions_build_config.bzl
@@ -65,6 +65,7 @@ EXTENSIONS = {
     "envoy.filters.network.echo": "//source/extensions/filters/network/echo:config",
     "envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config",
     "envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config",
+    "envoy.filters.network.kafka": "//source/extensions/filters/network/kafka:config",
     "envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config",
     "envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config",
     "envoy.filters.network.rbac": "//source/extensions/filters/network/rbac:config",
diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD
new file mode 100644
index 0000000000000..218e7fc5076a2
--- /dev/null
+++ b/source/extensions/filters/network/kafka/BUILD
@@ -0,0 +1,82 @@
+licenses(["notice"])  # Apache 2
+
+# Kafka network filter.
+# Public docs: docs/root/configuration/network_filters/kafka_filter.rst
+
+load(
+    "//bazel:envoy_build_system.bzl",
+    "envoy_cc_library",
+    "envoy_package",
+)
+
+envoy_package()
+
+envoy_cc_library(
+    name = "config",
+)
+
+envoy_cc_library(
+    name = "kafka_request_codec_lib",
+    srcs = ["request_codec.cc"],
+    hdrs = [
+        "codec.h",
+        "request_codec.h",
+    ],
+    deps = [
+        ":kafka_request_lib",
+        "//source/common/buffer:buffer_lib",
+    ],
+)
+
+envoy_cc_library(
+    name = "kafka_request_lib",
+    srcs = ["kafka_request.cc"],
+    hdrs = ["kafka_request.h"],
+    deps = [
+        ":parser_lib",
+        ":serialization_lib",
+        "//source/common/common:assert_lib",
+        "//source/common/common:minimal_logger_lib",
+    ],
+)
+
+envoy_cc_library(
+    name = "parser_lib",
+    hdrs = ["parser.h"],
+    deps = [
+        ":kafka_protocol_lib",
+        ":message_lib",
+        "//source/common/common:minimal_logger_lib",
+    ],
+)
+
+envoy_cc_library(
+    name = "message_lib",
+    hdrs = [
+        "message.h",
+    ],
+    deps = [
+    ],
+)
+
+envoy_cc_library(
+    name = "serialization_lib",
+    hdrs = ["serialization.h"],
+    deps = [
+        ":kafka_protocol_lib",
+        "//include/envoy/buffer:buffer_interface",
+        "//source/common/common:byte_order_lib",
+    ],
+)
+
+envoy_cc_library(
+    name = "kafka_protocol_lib",
+    hdrs = [
+        "kafka_protocol.h",
+        "kafka_types.h",
+    ],
+    external_deps = ["abseil_optional"],
+    deps = [
+        "//source/common/common:macros",
+    ],
+)
diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h
new file mode 100644
index 0000000000000..0e255b509bc21
--- /dev/null
+++ b/source/extensions/filters/network/kafka/codec.h
@@ -0,0 +1,34 @@
+#pragma once
+
+#include "envoy/buffer/buffer.h"
+#include "envoy/common/pure.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+/**
+ * Kafka message decoder
+ * @tparam MT message type (Kafka request or Kafka response)
+ */
+template <typename MT> class MessageDecoder {
+public:
+  virtual ~MessageDecoder() = default;
+
+  virtual void onData(Buffer::Instance& data) PURE;
+};
+
+/**
+ * Kafka message encoder
+ * @tparam MT message type (Kafka request or Kafka response)
+ */
+template <typename MT> class MessageEncoder {
+public:
+  virtual ~MessageEncoder() = default;
+
+  virtual void encode(const MT& message) PURE;
+};
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
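For orientation, this is roughly how a network filter could drive the decoder half of the codec; the filter class and member names below are a hypothetical sketch, not part of this patch:

  // Hypothetical wiring only -- remaining ReadFilter overrides omitted.
  class ExampleKafkaFilter : public Network::ReadFilter {
  public:
    Network::FilterStatus onData(Buffer::Instance& data, bool) override {
      request_decoder_->onData(data); // decoder buffers partial requests across calls
      return Network::FilterStatus::Continue;
    }

  private:
    std::unique_ptr<MessageDecoder<Request>> request_decoder_; // assumed concrete decoder
  };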
diff --git a/source/extensions/filters/network/kafka/kafka_protocol.h b/source/extensions/filters/network/kafka/kafka_protocol.h
new file mode 100644
index 0000000000000..f6452abd3e859
--- /dev/null
+++ b/source/extensions/filters/network/kafka/kafka_protocol.h
@@ -0,0 +1,123 @@
+#pragma once
+
+#include <vector>
+
+#include "common/common/macros.h"
+
+#include "extensions/filters/network/kafka/kafka_types.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+// from http://kafka.apache.org/protocol.html#protocol_api_keys
+enum RequestType : INT16 {
+  Produce = 0,
+  Fetch = 1,
+  ListOffsets = 2,
+  Metadata = 3,
+  LeaderAndIsr = 4,
+  StopReplica = 5,
+  UpdateMetadata = 6,
+  ControlledShutdown = 7,
+  OffsetCommit = 8,
+  OffsetFetch = 9,
+  FindCoordinator = 10,
+  JoinGroup = 11,
+  Heartbeat = 12,
+  LeaveGroup = 13,
+  SyncGroup = 14,
+  DescribeGroups = 15,
+  ListGroups = 16,
+  SaslHandshake = 17,
+  ApiVersions = 18,
+  CreateTopics = 19,
+  DeleteTopics = 20,
+  DeleteRecords = 21,
+  InitProducerId = 22,
+  OffsetForLeaderEpoch = 23,
+  AddPartitionsToTxn = 24,
+  AddOffsetsToTxn = 25,
+  EndTxn = 26,
+  WriteTxnMarkers = 27,
+  TxnOffsetCommit = 28,
+  DescribeAcls = 29,
+  CreateAcls = 30,
+  DeleteAcls = 31,
+  DescribeConfigs = 32,
+  AlterConfigs = 33,
+  AlterReplicaLogDirs = 34,
+  DescribeLogDirs = 35,
+  SaslAuthenticate = 36,
+  CreatePartitions = 37,
+  CreateDelegationToken = 38,
+  RenewDelegationToken = 39,
+  ExpireDelegationToken = 40,
+  DescribeDelegationToken = 41,
+  DeleteGroups = 42
+};
+
+struct RequestSpec {
+  const INT16 api_key_;
+  const std::string name_;
+};
+
+struct KafkaRequest {
+
+  // clang-format off
+  static const std::vector<RequestSpec>& requests() {
+    CONSTRUCT_ON_FIRST_USE(
+        std::vector<RequestSpec>,
+        {RequestType::Produce, "Produce"},
+        {RequestType::Fetch, "Fetch"},
+        {RequestType::ListOffsets, "ListOffsets"},
+        {RequestType::Metadata, "Metadata"},
+        {RequestType::LeaderAndIsr, "LeaderAndIsr"},
+        {RequestType::StopReplica, "StopReplica"},
+        {RequestType::UpdateMetadata, "UpdateMetadata"},
+        {RequestType::ControlledShutdown, "ControlledShutdown"},
+        {RequestType::OffsetCommit, "OffsetCommit"},
+        {RequestType::OffsetFetch, "OffsetFetch"},
+        {RequestType::FindCoordinator, "FindCoordinator"},
+        {RequestType::JoinGroup, "JoinGroup"},
+        {RequestType::Heartbeat, "Heartbeat"},
+        {RequestType::LeaveGroup, "LeaveGroup"},
+        {RequestType::SyncGroup, "SyncGroup"},
+        {RequestType::DescribeGroups, "DescribeGroups"},
+        {RequestType::ListGroups, "ListGroups"},
+        {RequestType::SaslHandshake, "SaslHandshake"},
+        {RequestType::ApiVersions, "ApiVersions"},
+        {RequestType::CreateTopics, "CreateTopics"},
+        {RequestType::DeleteTopics, "DeleteTopics"},
+        {RequestType::DeleteRecords, "DeleteRecords"},
+        {RequestType::InitProducerId, "InitProducerId"},
+        {RequestType::OffsetForLeaderEpoch, "OffsetForLeaderEpoch"},
+        {RequestType::AddPartitionsToTxn, "AddPartitionsToTxn"},
+        {RequestType::AddOffsetsToTxn, "AddOffsetsToTxn"},
+        {RequestType::EndTxn, "EndTxn"},
+        {RequestType::WriteTxnMarkers, "WriteTxnMarkers"},
+        {RequestType::TxnOffsetCommit, "TxnOffsetCommit"},
+        {RequestType::DescribeAcls, "DescribeAcls"},
+        {RequestType::CreateAcls, "CreateAcls"},
+        {RequestType::DeleteAcls, "DeleteAcls"},
+        {RequestType::DescribeConfigs, "DescribeConfigs"},
+        {RequestType::AlterConfigs, "AlterConfigs"},
+        {RequestType::AlterReplicaLogDirs, "AlterReplicaLogDirs"},
+        {RequestType::DescribeLogDirs, "DescribeLogDirs"},
+        {RequestType::SaslAuthenticate, "SaslAuthenticate"},
+        {RequestType::CreatePartitions, "CreatePartitions"},
+        {RequestType::CreateDelegationToken, "CreateDelegationToken"},
+        {RequestType::RenewDelegationToken, "RenewDelegationToken"},
+        {RequestType::ExpireDelegationToken, "ExpireDelegationToken"},
+        {RequestType::DescribeDelegationToken, "DescribeDelegationToken"},
+        {RequestType::DeleteGroups, "DeleteGroups"}
+    );
+  }
+  // clang-format on
+};
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
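As a usage sketch, the requests() table above can back a reverse lookup from a wire api_key to a printable name; the helper below is illustrative and not part of the patch:

  // Illustrative helper, assuming only the types defined above.
  const std::string& requestName(INT16 api_key) {
    for (const RequestSpec& spec : KafkaRequest::requests()) {
      if (spec.api_key_ == api_key) {
        return spec.name_;
      }
    }
    CONSTRUCT_ON_FIRST_USE(std::string, "UNKNOWN"); // fallback for api keys > 42
  }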
diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request.cc
new file mode 100644
index 0000000000000..e4ec7b0a8d25f
--- /dev/null
+++ b/source/extensions/filters/network/kafka/kafka_request.cc
@@ -0,0 +1,135 @@
+#include "extensions/filters/network/kafka/kafka_request.h"
+
+#include "extensions/filters/network/kafka/kafka_protocol.h"
+#include "extensions/filters/network/kafka/parser.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+// === REQUEST PARSER MAPPING (REQUEST TYPE => PARSER) =========================
+
+GeneratorMap computeGeneratorMap(std::vector<ParserSpec> specs) {
+  GeneratorMap result;
+  for (auto& spec : specs) {
+    auto generators = result[spec.api_key_];
+    if (!generators) {
+      generators = std::make_shared<std::unordered_map<INT16, GeneratorFunction>>();
+      result[spec.api_key_] = generators;
+    }
+    for (INT16 api_version : spec.api_versions_) {
+      (*generators)[api_version] = spec.generator_;
+    }
+  }
+
+  return result;
+}
+
+#define PARSER_SPEC(REQUEST_NAME, PARSER_VERSION, ...)                                             \
+  ParserSpec {                                                                                     \
+    RequestType::REQUEST_NAME, {__VA_ARGS__}, [](RequestContextSharedPtr arg) -> ParserSharedPtr { \
+      return std::make_shared<REQUEST_NAME##Request##PARSER_VERSION##Parser>(arg);                 \
+    }                                                                                              \
+  }
+
+const RequestParserResolver RequestParserResolver::KAFKA_0_11{{
+    ParserSpec{RequestType::Produce,
+               {0, 1, 2},
+               [](RequestContextSharedPtr arg) -> ParserSharedPtr {
+                 return std::make_shared<FatProduceRequestV0Parser>(arg);
+               }},
+    ParserSpec{RequestType::Produce,
+               {3},
+               [](RequestContextSharedPtr arg) -> ParserSharedPtr {
+                 return std::make_shared<FatProduceRequestV3Parser>(arg);
+               }},
+    PARSER_SPEC(Fetch, V0, 0, 1, 2),
+    PARSER_SPEC(Fetch, V3, 3),
+    PARSER_SPEC(Fetch, V4, 4),
+    PARSER_SPEC(Fetch, V5, 5),
+    PARSER_SPEC(ListOffsets, V0, 0),
+    PARSER_SPEC(ListOffsets, V1, 1),
+    PARSER_SPEC(ListOffsets, V2, 2),
+    PARSER_SPEC(Metadata, V0, 0, 1, 2, 3),
+    PARSER_SPEC(Metadata, V4, 4),
+    PARSER_SPEC(LeaderAndIsr, V0, 0),
+    PARSER_SPEC(StopReplica, V0, 0),
+    PARSER_SPEC(UpdateMetadata, V0, 0),
+    PARSER_SPEC(UpdateMetadata, V1, 1),
+    PARSER_SPEC(UpdateMetadata, V2, 2),
+    PARSER_SPEC(UpdateMetadata, V3, 3),
+    PARSER_SPEC(ControlledShutdown, V1, 1),
+    PARSER_SPEC(OffsetCommit, V0, 0),
+    PARSER_SPEC(OffsetCommit, V1, 1),
+    PARSER_SPEC(OffsetCommit, V2, 2, 3),
+    PARSER_SPEC(OffsetFetch, V0, 0, 1, 2, 3),
+    // XXX(adam.kotwasinski) missing request types here
+    PARSER_SPEC(ApiVersions, V0, 0, 1),
+}};
+
+ParserSharedPtr RequestParserResolver::createParser(INT16 api_key, INT16 api_version,
+                                                    RequestContextSharedPtr context) const {
+  const auto api_versions_ptr = generators_.find(api_key);
+  // unknown api_key
+  if (generators_.end() == api_versions_ptr) {
+    return std::make_shared<SentinelConsumer>(context);
+  }
+  const auto api_versions = api_versions_ptr->second;
+
+  // unknown api_version
+  const auto generator = api_versions->find(api_version);
+  if (api_versions->end() == generator) {
+    return std::make_shared<SentinelConsumer>(context);
+  }
+
+  // found matching parser generator, create parser
+  return generator->second(context);
+}
+
+// === HEADER PARSERS ==========================================================
+
+ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining) {
+  buffer_.feed(buffer, remaining);
+  if (buffer_.ready()) {
+    context_->remaining_request_size_ = buffer_.get();
+    return ParseResponse::nextParser(
+        std::make_shared<RequestHeaderParser>(parser_resolver_, context_));
+  } else {
+    return ParseResponse::stillWaiting();
+  }
+}
+
+ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remaining) {
+  context_->remaining_request_size_ -= buffer_.feed(buffer, remaining);
+
+  if (buffer_.ready()) {
+    RequestHeader request_header = buffer_.get();
+    context_->request_header_ = request_header;
+    ParserSharedPtr next_parser = parser_resolver_.createParser(
+        request_header.api_key_, request_header.api_version_, context_);
+    return ParseResponse::nextParser(next_parser);
+  } else {
+    return ParseResponse::stillWaiting();
+  }
+}
+
+// === UNKNOWN REQUEST =========================================================
+
+ParseResponse SentinelConsumer::parse(const char*& buffer, uint64_t& remaining) {
+  const size_t min = std::min<uint64_t>(context_->remaining_request_size_, remaining);
+  buffer += min;
+  remaining -= min;
+  context_->remaining_request_size_ -= min;
+  if (0 == context_->remaining_request_size_) {
+    return ParseResponse::parsedMessage(
+        std::make_shared<UnknownRequest>(context_->request_header_));
+  } else {
+    return ParseResponse::stillWaiting();
+  }
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
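To make the parser hand-off above concrete, here is an annotated wire trace of a minimal ApiVersions v0 request (example values, not taken from the patch):

  00 00 00 0e          INT32 size = 14          -> consumed by RequestStartParser
  00 12                INT16 api_key = 18        \
  00 00                INT16 api_version = 0      } consumed by RequestHeaderParser, which
  00 00 00 01          INT32 correlation_id = 1   } then asks the resolver for an
  00 04 74 65 73 74    client_id = "test"        /  ApiVersionsRequestV0Parser
  (empty payload)                                -> that parser returns the parsed message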
diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h
new file mode 100644
index 0000000000000..c7b8754c36342
--- /dev/null
+++ b/source/extensions/filters/network/kafka/kafka_request.h
@@ -0,0 +1,1316 @@
+#pragma once
+
+#include <functional>
+
+#include "envoy/common/exception.h"
+
+#include "common/common/assert.h"
+
+#include "extensions/filters/network/kafka/kafka_protocol.h"
+#include "extensions/filters/network/kafka/parser.h"
+#include "extensions/filters/network/kafka/serialization.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+// === VECTOR ==================================================================
+
+template <typename T> std::ostream& operator<<(std::ostream& os, const std::vector<T>& arg) {
+  os << "[";
+  for (auto iter = arg.begin(); iter != arg.end(); iter++) {
+    if (iter != arg.begin()) {
+      os << ", ";
+    }
+    os << *iter;
+  }
+  os << "]";
+  return os;
+}
+
+template <typename T> std::ostream& operator<<(std::ostream& os, const absl::optional<T>& arg) {
+  if (arg.has_value()) {
+    os << *arg;
+  } else {
+    os << "<null>";
+  }
+  return os;
+}
+
+// === REQUEST HEADER ==========================================================
+
+struct RequestHeader {
+  INT16 api_key_;
+  INT16 api_version_;
+  INT32 correlation_id_;
+  NULLABLE_STRING client_id_;
+
+  bool operator==(const RequestHeader& rhs) const {
+    return api_key_ == rhs.api_key_ && api_version_ == rhs.api_version_ &&
+           correlation_id_ == rhs.correlation_id_ && client_id_ == rhs.client_id_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const RequestHeader& arg) {
+    return os << "{api_key=" << arg.api_key_ << ", api_version=" << arg.api_version_
+              << ", correlation_id=" << arg.correlation_id_ << ", client_id=" << arg.client_id_
+              << "}";
+  };
+};
+
+struct RequestContext {
+  INT32 remaining_request_size_{0};
+  RequestHeader request_header_{};
+
+  friend std::ostream& operator<<(std::ostream& os, const RequestContext& arg) {
+    return os << "{header=" << arg.request_header_ << ", remaining=" << arg.remaining_request_size_
+              << "}";
+  }
+};
+
+typedef std::shared_ptr<RequestContext> RequestContextSharedPtr;
+
+// === REQUEST PARSER MAPPING (REQUEST TYPE => PARSER) =========================
+
+// a function generating a parser with given context
+typedef std::function<ParserSharedPtr(RequestContextSharedPtr)> GeneratorFunction;
+
+// two-level map: api_key -> api_version -> generator function
+typedef std::unordered_map<INT16, std::shared_ptr<std::unordered_map<INT16, GeneratorFunction>>>
+    GeneratorMap;
+
+struct ParserSpec {
+  const INT16 api_key_;
+  const std::vector<INT16> api_versions_;
+  const GeneratorFunction generator_;
+};
+
+// helper function that generates a map from specs looking like { api_key, api_versions... }
+GeneratorMap computeGeneratorMap(std::vector<ParserSpec> arg);
+
+/**
+ * Provides the parser that is responsible for consuming the request-specific data
+ * In other words: provides (api_key, api_version) -> Parser function
+ */
+class RequestParserResolver {
+public:
+  RequestParserResolver(std::vector<ParserSpec> arg) : generators_{computeGeneratorMap(arg)} {};
+  virtual ~RequestParserResolver() = default;
+
+  virtual ParserSharedPtr createParser(INT16 api_key, INT16 api_version,
+                                       RequestContextSharedPtr context) const;
+
+  static const RequestParserResolver KAFKA_0_11;
+
+private:
+  GeneratorMap generators_;
+};
+
+// === INITIAL PARSERS =========================================================
+
+/**
+ * Request start parser just consumes the length of request
+ */
+class RequestStartParser : public Parser {
+public:
+  RequestStartParser(const RequestParserResolver& parser_resolver)
+      : parser_resolver_{parser_resolver}, context_{std::make_shared<RequestContext>()} {};
+
+  ParseResponse parse(const char*& buffer, uint64_t& remaining);
+
+  const RequestContextSharedPtr contextForTest() const { return context_; }
+
+private:
+  const RequestParserResolver& parser_resolver_;
+  const RequestContextSharedPtr context_;
+  Int32Buffer buffer_;
+};
+
+class RequestHeaderBuffer
+    : public CompositeBuffer<RequestHeader, Int16Buffer, Int16Buffer, Int32Buffer,
+                             NullableStringBuffer> {};
+
+/**
+ * Request header parser consumes request header
+ */
+class RequestHeaderParser : public Parser {
+public:
+  RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context)
+      : parser_resolver_{parser_resolver}, context_{context} {};
+
+  ParseResponse parse(const char*& buffer, uint64_t& remaining);
+
+  const RequestContextSharedPtr contextForTest() const { return context_; }
+
+private:
+  const RequestParserResolver& parser_resolver_;
+  const RequestContextSharedPtr context_;
+  RequestHeaderBuffer buffer_;
+};
+
+// === BUFFERED PARSER =========================================================
+
+/**
+ * Buffered parser uses a single buffer to construct a response
+ * This parser is responsible for consuming request-specific data (e.g. topic names) and always
+ * returns a parsed message
+ */
+template <typename RT, typename BT> class BufferedParser : public Parser {
+public:
+  BufferedParser(RequestContextSharedPtr context) : context_{context} {};
+  ParseResponse parse(const char*& buffer, uint64_t& remaining) override;
+
+protected:
+  RequestContextSharedPtr context_;
+  BT buffer_;
+};
+
+template <typename RT, typename BT>
+ParseResponse BufferedParser<RT, BT>::parse(const char*& buffer, uint64_t& remaining) {
+  context_->remaining_request_size_ -= buffer_.feed(buffer, remaining);
+  if (buffer_.ready()) {
+    // after a successful parse, there should be nothing left
+    ASSERT(0 == context_->remaining_request_size_);
+    RT request = buffer_.get();
+    request.header() = context_->request_header_;
+    ENVOY_LOG(trace, "parsed request {}: {}", *context_, request);
+    MessageSharedPtr msg = std::make_shared<RT>(request);
+    return ParseResponse::parsedMessage(msg);
+  } else {
+    return ParseResponse::stillWaiting();
+  }
+}
+
+// names of Buffers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names
+
+#define DEFINE_REQUEST_PARSER(REQUEST_TYPE, VERSION)                                               \
+  class REQUEST_TYPE##VERSION##Parser                                                              \
+      : public BufferedParser<REQUEST_TYPE, REQUEST_TYPE##VERSION##Buffer> {                       \
+  public:                                                                                          \
+    REQUEST_TYPE##VERSION##Parser(RequestContextSharedPtr ctx) : BufferedParser{ctx} {};           \
+  };
+
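For instance, DEFINE_REQUEST_PARSER(FetchRequest, V0) from above expands (modulo formatting) to:

  class FetchRequestV0Parser : public BufferedParser<FetchRequest, FetchRequestV0Buffer> {
  public:
    FetchRequestV0Parser(RequestContextSharedPtr ctx) : BufferedParser{ctx} {};
  };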
+// === ABSTRACT REQUEST ========================================================
+
+class Request : public Message {
+public:
+  /**
+   * Request header fields need to be initialized by user in case of newly created requests
+   */
+  Request(INT16 api_key) : request_header_{api_key, 0, 0, ""} {};
+
+  Request(const RequestHeader& request_header) : request_header_{request_header} {};
+
+  RequestHeader& header() { return request_header_; }
+
+  INT16& apiVersion() { return request_header_.api_version_; }
+  INT16 apiVersion() const { return request_header_.api_version_; }
+
+  INT32& correlationId() { return request_header_.correlation_id_; }
+
+  NULLABLE_STRING& clientId() { return request_header_.client_id_; }
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(request_header_.api_key_, dst);
+    written += encoder.encode(request_header_.api_version_, dst);
+    written += encoder.encode(request_header_.correlation_id_, dst);
+    written += encoder.encode(request_header_.client_id_, dst);
+    written += encodeDetails(dst, encoder);
+    return written;
+  }
+
+  std::ostream& print(std::ostream& os) const override final {
+    os << request_header_ << " "; // not very pretty
+    return printDetails(os);
+  }
+
+protected:
+  virtual size_t encodeDetails(Buffer::Instance&, EncodingContext&) const PURE;
+
+  virtual std::ostream& printDetails(std::ostream&) const PURE;
+
+  RequestHeader request_header_;
+};
+
+// === PRODUCE (0) =============================================================
+
+/**
+ * Produce request parser is a special case that has two corresponding parsers
+ * One parser captures data, the other one does not, only saving the length of data provided
+ * This might be used in filters that do not need access to data (e.g. only want to update request
+ * type metrics)
+ */
+
+// holds data sent by client
+struct FatProducePartition {
+  const INT32 partition_;
+  const NULLABLE_BYTES data_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(data_, dst);
+    return written;
+  }
+
+  bool operator==(const FatProducePartition& rhs) const {
+    return partition_ == rhs.partition_ && data_ == rhs.data_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const FatProducePartition& arg) {
+    os << "{partition=" << arg.partition_ << ", data(size)=";
+    if (arg.data_.has_value()) {
+      os << arg.data_->size();
+    } else {
+      os << "<null>";
+    }
+    return os << "}";
+  }
+};
+
+// does not carry data, only its length
+struct ThinProducePartition {
+  const INT32 partition_;
+  const INT32 data_size_;
+
+  size_t encode(Buffer::Instance&, EncodingContext&) const {
+    throw EnvoyException("ThinProducePartition cannot be encoded");
+  }
+
+  bool operator==(const ThinProducePartition& rhs) const {
+    return partition_ == rhs.partition_ && data_size_ == rhs.data_size_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const ThinProducePartition& arg) {
+    return os << "{partition=" << arg.partition_ << ", data_size=" << arg.data_size_ << "}";
+  }
+};
+
+template <typename PT> struct ProduceTopic {
+  const STRING topic_;
+  const NULLABLE_ARRAY<PT> partitions_;
+
+  bool operator==(const ProduceTopic& rhs) const {
+    return topic_ == rhs.topic_ && partitions_ == rhs.partitions_;
+  };
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const ProduceTopic& arg) {
+    return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}";
+  }
+};
+
+typedef ProduceTopic<FatProducePartition> FatProduceTopic;
+typedef ProduceTopic<ThinProducePartition> ThinProduceTopic;
+
+/**
+ * PT carries partition type, which can be capturing (contains bytes) or non-capturing (contains
+ * bytes' length only)
+ */
+template <typename PT> class ProduceRequest : public Request {
+public:
+  // v0 .. v2
+  ProduceRequest(INT16 acks, INT32 timeout, NULLABLE_ARRAY<ProduceTopic<PT>> topics)
+      : ProduceRequest(absl::nullopt, acks, timeout, topics){};
+
+  // v3
+  ProduceRequest(NULLABLE_STRING transactional_id, INT16 acks, INT32 timeout,
+                 NULLABLE_ARRAY<ProduceTopic<PT>> topics)
+      : Request{RequestType::Produce},
+        transactional_id_{transactional_id}, acks_{acks}, timeout_{timeout}, topics_{topics} {};
+
+  bool operator==(const ProduceRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && transactional_id_ == rhs.transactional_id_ &&
+           acks_ == rhs.acks_ && timeout_ == rhs.timeout_ && topics_ == rhs.topics_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    if (request_header_.api_version_ >= 3) {
+      written += encoder.encode(transactional_id_, dst);
+    }
+    written += encoder.encode(acks_, dst);
+    written += encoder.encode(timeout_, dst);
+    written += encoder.encode(topics_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{transactional_id=" << transactional_id_ << ", acks=" << acks_
+              << ", timeout=" << timeout_ << ", topics=" << topics_ << "}";
+  };
+
+private:
+  const NULLABLE_STRING transactional_id_;
+  const INT16 acks_;
+  const INT32 timeout_;
+  const NULLABLE_ARRAY<ProduceTopic<PT>> topics_;
+};
+
+typedef ProduceRequest<FatProducePartition> FatProduceRequest;
+typedef ProduceRequest<ThinProducePartition> ThinProduceRequest;
+
+// clang-format off
+class ThinProducePartitionArrayBuffer : public ArrayBuffer<ThinProducePartition, CompositeBuffer<ThinProducePartition, Int32Buffer, NullableBytesIgnoringBuffer>> {};
+class ProducePartitionArrayBuffer : public ArrayBuffer<FatProducePartition, CompositeBuffer<FatProducePartition, Int32Buffer, NullableBytesBuffer>> {};
+
+class ThinProduceTopicArrayBuffer : public ArrayBuffer<ThinProduceTopic, CompositeBuffer<ThinProduceTopic, StringBuffer, ThinProducePartitionArrayBuffer>> {};
+class ProduceTopicArrayBuffer : public ArrayBuffer<FatProduceTopic, CompositeBuffer<FatProduceTopic, StringBuffer, ProducePartitionArrayBuffer>> {};
+
+class ThinProduceRequestV0Buffer : public CompositeBuffer<ThinProduceRequest, Int16Buffer, Int32Buffer, ThinProduceTopicArrayBuffer> {};
+class ThinProduceRequestV3Buffer : public CompositeBuffer<ThinProduceRequest, NullableStringBuffer, Int16Buffer, Int32Buffer, ThinProduceTopicArrayBuffer> {};
+class FatProduceRequestV0Buffer : public CompositeBuffer<FatProduceRequest, Int16Buffer, Int32Buffer, ProduceTopicArrayBuffer> {};
+class FatProduceRequestV3Buffer : public CompositeBuffer<FatProduceRequest, NullableStringBuffer, Int16Buffer, Int32Buffer, ProduceTopicArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(ThinProduceRequest, V0);
+DEFINE_REQUEST_PARSER(ThinProduceRequest, V3);
+DEFINE_REQUEST_PARSER(FatProduceRequest, V0);
+DEFINE_REQUEST_PARSER(FatProduceRequest, V3);
+// clang-format on
+
+// === FETCH (1) ===============================================================
+
+struct FetchRequestPartition {
+  const INT32 partition_;
+  const INT64 fetch_offset_;
+  const INT64 log_start_offset_; // since v5
+  const INT32 max_bytes_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(fetch_offset_, dst);
+    if (encoder.apiVersion() >= 5) {
+      written += encoder.encode(log_start_offset_, dst);
+    }
+    written += encoder.encode(max_bytes_, dst);
+    return written;
+  }
+
+  friend std::ostream& operator<<(std::ostream& os, const FetchRequestPartition& arg) {
+    return os << "{partition=" << arg.partition_ << ", fetch_offset=" << arg.fetch_offset_
+              << ", log_start_offset=" << arg.log_start_offset_ << ", max_bytes=" << arg.max_bytes_
+              << "}";
+  }
+
+  bool operator==(const FetchRequestPartition& rhs) const {
+    return partition_ == rhs.partition_ && fetch_offset_ == rhs.fetch_offset_ &&
+           log_start_offset_ == rhs.log_start_offset_ && max_bytes_ == rhs.max_bytes_;
+  };
+
+  // v0 .. v4
+  FetchRequestPartition(INT32 partition, INT64 fetch_offset, INT32 max_bytes)
+      : FetchRequestPartition(partition, fetch_offset, -1, max_bytes){};
+
+  // v5
+  FetchRequestPartition(INT32 partition, INT64 fetch_offset, INT64 log_start_offset,
+                        INT32 max_bytes)
+      : partition_{partition}, fetch_offset_{fetch_offset}, log_start_offset_{log_start_offset},
+        max_bytes_{max_bytes} {};
+};
+
+struct FetchRequestTopic {
+  const STRING topic_;
+  const NULLABLE_ARRAY<FetchRequestPartition> partitions_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  bool operator==(const FetchRequestTopic& rhs) const {
+    return topic_ == rhs.topic_ && partitions_ == rhs.partitions_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const FetchRequestTopic& arg) {
+    return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}";
+  }
+};
+
+class FetchRequest : public Request {
+public:
+  // v0 .. v2
+  FetchRequest(INT32 replica_id, INT32 max_wait_time, INT32 min_bytes,
+               NULLABLE_ARRAY<FetchRequestTopic> topics)
+      : FetchRequest(replica_id, max_wait_time, min_bytes, -1, topics){};
+
+  // v3
+  FetchRequest(INT32 replica_id, INT32 max_wait_time, INT32 min_bytes, INT32 max_bytes,
+               NULLABLE_ARRAY<FetchRequestTopic> topics)
+      : FetchRequest(replica_id, max_wait_time, min_bytes, max_bytes, -1, topics){};
+
+  // v4 .. v5
+  FetchRequest(INT32 replica_id, INT32 max_wait_time, INT32 min_bytes, INT32 max_bytes,
+               INT8 isolation_level, NULLABLE_ARRAY<FetchRequestTopic> topics)
+      : Request{RequestType::Fetch}, replica_id_{replica_id}, max_wait_time_{max_wait_time},
+        min_bytes_{min_bytes}, max_bytes_{max_bytes},
+        isolation_level_{isolation_level}, topics_{topics} {};
+
+  bool operator==(const FetchRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && replica_id_ == rhs.replica_id_ &&
+           max_wait_time_ == rhs.max_wait_time_ && min_bytes_ == rhs.min_bytes_ &&
+           max_bytes_ == rhs.max_bytes_ && isolation_level_ == rhs.isolation_level_ &&
+           topics_ == rhs.topics_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    INT16 api_version = request_header_.api_version_;
+    written += encoder.encode(replica_id_, dst);
+    written += encoder.encode(max_wait_time_, dst);
+    written += encoder.encode(min_bytes_, dst);
+    if (api_version >= 3) {
+      written += encoder.encode(max_bytes_, dst);
+    }
+    if (api_version >= 4) {
+      written += encoder.encode(isolation_level_, dst);
+    }
+    written += encoder.encode(topics_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{replica_id=" << replica_id_ << ", max_wait_time=" << max_wait_time_
+              << ", min_bytes=" << min_bytes_ << ", max_bytes=" << max_bytes_
+              << ", isolation_level=" << static_cast<INT32>(isolation_level_)
+              << ", topics=" << topics_ << "}";
+  }
+
+private:
+  const INT32 replica_id_;
+  const INT32 max_wait_time_;
+  const INT32 min_bytes_;
+  const INT32 max_bytes_;      // since v3
+  const INT8 isolation_level_; // since v4
+  const NULLABLE_ARRAY<FetchRequestTopic> topics_;
+};
+
+// clang-format off
+class FetchRequestPartitionV0Buffer : public CompositeBuffer<FetchRequestPartition, Int32Buffer, Int64Buffer, Int32Buffer> {};
+class FetchRequestPartitionV0ArrayBuffer : public ArrayBuffer<FetchRequestPartition, FetchRequestPartitionV0Buffer> {};
+class FetchRequestTopicV0Buffer : public CompositeBuffer<FetchRequestTopic, StringBuffer, FetchRequestPartitionV0ArrayBuffer> {};
+class FetchRequestTopicV0ArrayBuffer : public ArrayBuffer<FetchRequestTopic, FetchRequestTopicV0Buffer> {};
+
+class FetchRequestPartitionV5Buffer : public CompositeBuffer<FetchRequestPartition, Int32Buffer, Int64Buffer, Int64Buffer, Int32Buffer> {};
+class FetchRequestPartitionV5ArrayBuffer : public ArrayBuffer<FetchRequestPartition, FetchRequestPartitionV5Buffer> {};
+class FetchRequestTopicV5Buffer : public CompositeBuffer<FetchRequestTopic, StringBuffer, FetchRequestPartitionV5ArrayBuffer> {};
+class FetchRequestTopicV5ArrayBuffer : public ArrayBuffer<FetchRequestTopic, FetchRequestTopicV5Buffer> {};
+
+class FetchRequestV0Buffer : public CompositeBuffer<FetchRequest, Int32Buffer, Int32Buffer, Int32Buffer, FetchRequestTopicV0ArrayBuffer> {};
+class FetchRequestV3Buffer : public CompositeBuffer<FetchRequest, Int32Buffer, Int32Buffer, Int32Buffer, Int32Buffer, FetchRequestTopicV0ArrayBuffer> {};
+class FetchRequestV4Buffer : public CompositeBuffer<FetchRequest, Int32Buffer, Int32Buffer, Int32Buffer, Int32Buffer, Int8Buffer, FetchRequestTopicV0ArrayBuffer> {};
+class FetchRequestV5Buffer : public CompositeBuffer<FetchRequest, Int32Buffer, Int32Buffer, Int32Buffer, Int32Buffer, Int8Buffer, FetchRequestTopicV5ArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(FetchRequest, V0);
+DEFINE_REQUEST_PARSER(FetchRequest, V3);
+DEFINE_REQUEST_PARSER(FetchRequest, V4);
+DEFINE_REQUEST_PARSER(FetchRequest, V5);
+// clang-format on
+
+// === LIST OFFSETS (2) ========================================================
+
+struct ListOffsetsPartition {
+  const INT32 partition_;
+  const INT64 timestamp_;
+  const INT32 max_num_offsets_; // only v0
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(timestamp_, dst);
+    if (encoder.apiVersion() == 0) {
+      written += encoder.encode(max_num_offsets_, dst);
+    }
+    return written;
+  }
+
+  bool operator==(const ListOffsetsPartition& rhs) const {
+    return partition_ == rhs.partition_ && timestamp_ == rhs.timestamp_ &&
+           max_num_offsets_ == rhs.max_num_offsets_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const ListOffsetsPartition& arg) {
+    return os << "{partition=" << arg.partition_ << ", timestamp=" << arg.timestamp_
+              << ", max_num_offsets=" << arg.max_num_offsets_ << "}";
+  }
+
+  // v0
+  ListOffsetsPartition(INT32 partition, INT64 timestamp, INT32 max_num_offsets)
+      : partition_{partition}, timestamp_{timestamp}, max_num_offsets_{max_num_offsets} {};
+
+  // v1 .. v2
+  ListOffsetsPartition(INT32 partition, INT64 timestamp)
+      : ListOffsetsPartition(partition, timestamp, -1){};
+};
+
+struct ListOffsetsTopic {
+  const STRING topic_;
+  const NULLABLE_ARRAY<ListOffsetsPartition> partitions_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  bool operator==(const ListOffsetsTopic& rhs) const {
+    return topic_ == rhs.topic_ && partitions_ == rhs.partitions_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const ListOffsetsTopic& arg) {
+    return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}";
+  }
+};
+
+class ListOffsetsRequest : public Request {
+public:
+  // v0 .. v1
+  ListOffsetsRequest(INT32 replica_id, NULLABLE_ARRAY<ListOffsetsTopic> topics)
+      : ListOffsetsRequest(replica_id, -1, topics){};
+
+  // v2
+  ListOffsetsRequest(INT32 replica_id, INT8 isolation_level,
+                     NULLABLE_ARRAY<ListOffsetsTopic> topics)
+      : Request{RequestType::ListOffsets}, replica_id_{replica_id},
+        isolation_level_{isolation_level}, topics_{topics} {};
+
+  bool operator==(const ListOffsetsRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && replica_id_ == rhs.replica_id_ &&
+           isolation_level_ == rhs.isolation_level_ && topics_ == rhs.topics_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(replica_id_, dst);
+    if (encoder.apiVersion() >= 2) {
+      written += encoder.encode(isolation_level_, dst);
+    }
+    written += encoder.encode(topics_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{replica_id=" << replica_id_
+              << ", isolation_level=" << static_cast<INT32>(isolation_level_)
+              << ", topics=" << topics_ << "}";
+  }
+
+private:
+  const INT32 replica_id_;
+  const INT8 isolation_level_; // since v2
+  const NULLABLE_ARRAY<ListOffsetsTopic> topics_;
+};
+
+// clang-format off
+class ListOffsetsPartitionV0Buffer : public CompositeBuffer<ListOffsetsPartition, Int32Buffer, Int64Buffer, Int32Buffer> {};
+class ListOffsetsPartitionV0ArrayBuffer : public ArrayBuffer<ListOffsetsPartition, ListOffsetsPartitionV0Buffer> {};
+class ListOffsetsTopicV0Buffer : public CompositeBuffer<ListOffsetsTopic, StringBuffer, ListOffsetsPartitionV0ArrayBuffer> {};
+class ListOffsetsTopicV0ArrayBuffer : public ArrayBuffer<ListOffsetsTopic, ListOffsetsTopicV0Buffer> {};
+
+class ListOffsetsPartitionV1Buffer : public CompositeBuffer<ListOffsetsPartition, Int32Buffer, Int64Buffer> {};
+class ListOffsetsPartitionV1ArrayBuffer : public ArrayBuffer<ListOffsetsPartition, ListOffsetsPartitionV1Buffer> {};
+class ListOffsetsTopicV1Buffer : public CompositeBuffer<ListOffsetsTopic, StringBuffer, ListOffsetsPartitionV1ArrayBuffer> {};
+class ListOffsetsTopicV1ArrayBuffer : public ArrayBuffer<ListOffsetsTopic, ListOffsetsTopicV1Buffer> {};
+
+class ListOffsetsRequestV0Buffer : public CompositeBuffer<ListOffsetsRequest, Int32Buffer, ListOffsetsTopicV0ArrayBuffer> {};
+class ListOffsetsRequestV1Buffer : public CompositeBuffer<ListOffsetsRequest, Int32Buffer, ListOffsetsTopicV1ArrayBuffer> {};
+class ListOffsetsRequestV2Buffer : public CompositeBuffer<ListOffsetsRequest, Int32Buffer, Int8Buffer, ListOffsetsTopicV1ArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(ListOffsetsRequest, V0);
+DEFINE_REQUEST_PARSER(ListOffsetsRequest, V1);
+DEFINE_REQUEST_PARSER(ListOffsetsRequest, V2);
+// clang-format on
+
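A round-trip usage sketch for the request classes so far; EncodingContext comes from serialization.h, and its single-argument constructor here is an assumption, not confirmed by this patch:

  // Illustrative only: serialize a v1 ListOffsets request into an Envoy buffer.
  ListOffsetsPartition partition{0, -1};             // v1 constructor: partition, timestamp
  ListOffsetsTopic topic{"my-topic", {{partition}}};
  ListOffsetsRequest request{-1, {{topic}}};         // v0/v1 constructor: replica_id, topics
  request.apiVersion() = 1;

  Buffer::OwnedImpl data;                            // from common/buffer/buffer_impl.h
  EncodingContext encoder{request.apiVersion()};     // assumed constructor
  request.encode(data, encoder);                     // writes header, then the v1 payload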
+// === METADATA (3) ============================================================
+
+class MetadataRequest : public Request {
+public:
+  // v0 .. v3
+  MetadataRequest(NULLABLE_ARRAY<STRING> topics) : MetadataRequest(topics, false){};
+
+  // v4
+  MetadataRequest(NULLABLE_ARRAY<STRING> topics, BOOLEAN allow_auto_topic_creation)
+      : Request{RequestType::Metadata}, topics_{topics}, allow_auto_topic_creation_{
+                                                             allow_auto_topic_creation} {};
+
+  bool operator==(const MetadataRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && topics_ == rhs.topics_ &&
+           allow_auto_topic_creation_ == rhs.allow_auto_topic_creation_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(topics_, dst);
+    if (encoder.apiVersion() >= 4) {
+      written += encoder.encode(allow_auto_topic_creation_, dst);
+    }
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{topics=" << topics_
+              << ", allow_auto_topic_creation=" << allow_auto_topic_creation_ << "}";
+  }
+
+private:
+  NULLABLE_ARRAY<STRING> topics_;
+  BOOLEAN allow_auto_topic_creation_; // since v4
+};
+
+// clang-format off
+class MetadataRequestTopicV0Buffer : public ArrayBuffer<STRING, StringBuffer> {};
+class MetadataRequestV0Buffer : public CompositeBuffer<MetadataRequest, MetadataRequestTopicV0Buffer> {};
+class MetadataRequestV4Buffer : public CompositeBuffer<MetadataRequest, MetadataRequestTopicV0Buffer, BooleanBuffer> {};
+
+DEFINE_REQUEST_PARSER(MetadataRequest, V0);
+DEFINE_REQUEST_PARSER(MetadataRequest, V4);
+// clang-format on
+
+// === LEADER-AND-ISR (4) ======================================================
+
+/**
+ * This structure is used in both LeaderAndIsr v0 & UpdateMetadata
+ */
+
+struct MetadataPartitionState {
+  const STRING topic_;
+  const INT32 partition_;
+  const INT32 controller_epoch_;
+  const INT32 leader_;
+  const INT32 leader_epoch_;
+  const NULLABLE_ARRAY<INT32> isr_;
+  const INT32 zk_version_;
+  const NULLABLE_ARRAY<INT32> replicas_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(controller_epoch_, dst);
+    written += encoder.encode(leader_, dst);
+    written += encoder.encode(leader_epoch_, dst);
+    written += encoder.encode(isr_, dst);
+    written += encoder.encode(zk_version_, dst);
+    written += encoder.encode(replicas_, dst);
+    return written;
+  }
+
+  bool operator==(const MetadataPartitionState& rhs) const {
+    return topic_ == rhs.topic_ && partition_ == rhs.partition_ &&
+           controller_epoch_ == rhs.controller_epoch_ && leader_ == rhs.leader_ &&
+           leader_epoch_ == rhs.leader_epoch_ && isr_ == rhs.isr_ &&
+           zk_version_ == rhs.zk_version_ && replicas_ == rhs.replicas_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const MetadataPartitionState& arg) {
+    return os << "{topic=" << arg.topic_ << ", partition=" << arg.partition_
+              << ", controller_epoch=" << arg.controller_epoch_ << ", leader=" << arg.leader_
+              << ", leader_epoch=" << arg.leader_epoch_ << ", isr=" << arg.isr_
+              << ", zk_version=" << arg.zk_version_ << ", replicas=" << arg.replicas_ << "}";
+  }
+};
+
+struct LeaderAndIsrLiveLeader {
+  const INT32 id_;
+  const STRING host_;
+  const INT32 port_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(id_, dst);
+    written += encoder.encode(host_, dst);
+    written += encoder.encode(port_, dst);
+    return written;
+  }
+
+  bool operator==(const LeaderAndIsrLiveLeader& rhs) const {
+    return id_ == rhs.id_ && host_ == rhs.host_ && port_ == rhs.port_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const LeaderAndIsrLiveLeader& arg) {
+    return os << "{id=" << arg.id_ << ", host=" << arg.host_ << ", port=" << arg.port_ << "}";
+  }
+};
+
+class LeaderAndIsrRequest : public Request {
+public:
+  // v0
+  LeaderAndIsrRequest(INT32 controller_id, INT32 controller_epoch,
+                      NULLABLE_ARRAY<MetadataPartitionState> partition_states,
+                      NULLABLE_ARRAY<LeaderAndIsrLiveLeader> live_readers)
+      : Request{RequestType::LeaderAndIsr}, controller_id_{controller_id},
+        controller_epoch_{controller_epoch}, partition_states_{partition_states},
+        live_readers_{live_readers} {};
+
+  bool operator==(const LeaderAndIsrRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && controller_id_ == rhs.controller_id_ &&
+           controller_epoch_ == rhs.controller_epoch_ &&
+           partition_states_ == rhs.partition_states_ && live_readers_ == rhs.live_readers_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(controller_id_, dst);
+    written += encoder.encode(controller_epoch_, dst);
+    written += encoder.encode(partition_states_, dst);
+    written += encoder.encode(live_readers_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{controller_id=" << controller_id_ << ", controller_epoch=" << controller_epoch_
+              << ", partition_states=" << partition_states_ << ", live_readers=" << live_readers_
+              << "}";
+  }
+
+private:
+  const INT32 controller_id_;
+  const INT32 controller_epoch_;
+  const NULLABLE_ARRAY<MetadataPartitionState> partition_states_;
+  const NULLABLE_ARRAY<LeaderAndIsrLiveLeader> live_readers_;
+};
+
+// clang-format off
+class MetadataPartitionStateV0Buffer : public CompositeBuffer<MetadataPartitionState,
+                                                              StringBuffer,
+                                                              Int32Buffer,
+                                                              Int32Buffer,
+                                                              Int32Buffer,
+                                                              Int32Buffer,
+                                                              ArrayBuffer<INT32, Int32Buffer>,
+                                                              Int32Buffer,
+                                                              ArrayBuffer<INT32, Int32Buffer>
+                                                              > {};
+class MetadataPartitionStateV0ArrayBuffer : public ArrayBuffer<MetadataPartitionState, MetadataPartitionStateV0Buffer> {};
+
+class LeaderAndIsrLiveLeaderV0Buffer : public CompositeBuffer<LeaderAndIsrLiveLeader, Int32Buffer, StringBuffer, Int32Buffer> {};
+class LeaderAndIsrLiveLeaderV0ArrayBuffer : public ArrayBuffer<LeaderAndIsrLiveLeader, LeaderAndIsrLiveLeaderV0Buffer> {};
+
+class LeaderAndIsrRequestV0Buffer : public CompositeBuffer<LeaderAndIsrRequest, Int32Buffer, Int32Buffer, MetadataPartitionStateV0ArrayBuffer, LeaderAndIsrLiveLeaderV0ArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(LeaderAndIsrRequest, V0);
+// clang-format on
+
+// === STOP REPLICA (5) ========================================================
+
+struct StopReplicaPartition {
+  const STRING topic_;
+  const INT32 partition_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partition_, dst);
+    return written;
+  }
+
+  bool operator==(const StopReplicaPartition& rhs) const {
+    return topic_ == rhs.topic_ && partition_ == rhs.partition_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const StopReplicaPartition& arg) {
+    return os << "{topic=" << arg.topic_ << ", partition=" << arg.partition_ << "}";
+  }
+};
+
+class StopReplicaRequest : public Request {
+public:
+  // v0
+  StopReplicaRequest(INT32 controller_id, INT32 controller_epoch, BOOLEAN delete_partitions,
+                     NULLABLE_ARRAY<StopReplicaPartition> partitions)
+      : Request{RequestType::StopReplica}, controller_id_{controller_id},
+        controller_epoch_{controller_epoch}, delete_partitions_{delete_partitions},
+        partitions_{partitions} {};
+
+  bool operator==(const StopReplicaRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && controller_id_ == rhs.controller_id_ &&
+           controller_epoch_ == rhs.controller_epoch_ &&
+           delete_partitions_ == rhs.delete_partitions_ && partitions_ == rhs.partitions_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(controller_id_, dst);
+    written += encoder.encode(controller_epoch_, dst);
+    written += encoder.encode(delete_partitions_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{controller_id=" << controller_id_ << ", controller_epoch=" << controller_epoch_
+              << ", delete_partitions=" << delete_partitions_ << ", partitions=" << partitions_
+              << "}";
+  }
+
+private:
+  const INT32 controller_id_;
+  const INT32 controller_epoch_;
+  const BOOLEAN delete_partitions_;
+  const NULLABLE_ARRAY<StopReplicaPartition> partitions_;
+};
+
+// clang-format off
+class StopReplicaPartitionV0Buffer : public CompositeBuffer<StopReplicaPartition, StringBuffer, Int32Buffer> {};
+class StopReplicaPartitionV0ArrayBuffer : public ArrayBuffer<StopReplicaPartition, StopReplicaPartitionV0Buffer> {};
+
+class StopReplicaRequestV0Buffer : public CompositeBuffer<StopReplicaRequest, Int32Buffer, Int32Buffer, BooleanBuffer, StopReplicaPartitionV0ArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(StopReplicaRequest, V0);
+// clang-format on
+
+// === UPDATE METADATA (6) =====================================================
+
+// uses MetadataPartitionState from LeaderAndIsr
+
+struct UpdateMetadataLiveBrokerEndpoint {
+  const INT32 port_;
+  const STRING host_;
+  const STRING listener_name_;
+  const INT16 security_protocol_type_;
+
+  // v1 .. v2
+  UpdateMetadataLiveBrokerEndpoint(INT32 port, STRING host, INT16 security_protocol_type)
+      : UpdateMetadataLiveBrokerEndpoint{port, host, "", security_protocol_type} {};
+
+  // v3
+  UpdateMetadataLiveBrokerEndpoint(INT32 port, STRING host, STRING listener_name,
+                                   INT16 security_protocol_type)
+      : port_{port}, host_{host}, listener_name_{listener_name}, security_protocol_type_{
+                                                                     security_protocol_type} {};
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(port_, dst);
+    written += encoder.encode(host_, dst);
+    if (encoder.apiVersion() >= 3) {
+      written += encoder.encode(listener_name_, dst);
+    }
+    written += encoder.encode(security_protocol_type_, dst);
+    return written;
+  }
+
+  bool operator==(const UpdateMetadataLiveBrokerEndpoint& rhs) const {
+    return port_ == rhs.port_ && host_ == rhs.host_ && listener_name_ == rhs.listener_name_ &&
+           security_protocol_type_ == rhs.security_protocol_type_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const UpdateMetadataLiveBrokerEndpoint& arg) {
+    return os << "{port=" << arg.port_ << ", host=" << arg.host_
+              << ", listener_name=" << arg.listener_name_
+              << ", security_protocol_type=" << arg.security_protocol_type_ << "}";
+  }
+};
+
+struct UpdateMetadataLiveBroker {
+  const INT32 id_;
+  const NULLABLE_ARRAY<UpdateMetadataLiveBrokerEndpoint> endpoints_;
+  const NULLABLE_STRING rack_; // since v2
+
+  // v0
+  // instead of having dedicated fields, store data as single UpdateMetadataLiveBrokerEndpoint (java
+  // client does it as well)
+  UpdateMetadataLiveBroker(INT32 id, STRING host, INT32 port)
+      : UpdateMetadataLiveBroker(id, {{UpdateMetadataLiveBrokerEndpoint{port, host, 0}}}){};
+
+  // v1
+  UpdateMetadataLiveBroker(INT32 id, NULLABLE_ARRAY<UpdateMetadataLiveBrokerEndpoint> endpoints)
+      : UpdateMetadataLiveBroker{id, endpoints, absl::nullopt} {};
+
+  // v2
+  UpdateMetadataLiveBroker(INT32 id, NULLABLE_ARRAY<UpdateMetadataLiveBrokerEndpoint> endpoints,
+                           NULLABLE_STRING rack)
+      : id_{id}, endpoints_{endpoints}, rack_{rack} {};
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(id_, dst);
+    if (encoder.apiVersion() == 0) {
+      // we stored host+port as endpoint, but need to serialize properly
+      const UpdateMetadataLiveBrokerEndpoint& only_endpoint = (*endpoints_)[0];
+      written += encoder.encode(only_endpoint.host_, dst);
+      written += encoder.encode(only_endpoint.port_, dst);
+    } else {
+      written += encoder.encode(endpoints_, dst);
+      if (encoder.apiVersion() >= 2) {
+        written += encoder.encode(rack_, dst);
+      }
+    }
+    return written;
+  }
+
+  bool operator==(const UpdateMetadataLiveBroker& rhs) const {
+    return id_ == rhs.id_ && endpoints_ == rhs.endpoints_ && rack_ == rhs.rack_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const UpdateMetadataLiveBroker& arg) {
+    return os << "{id=" << arg.id_ << ", endpoints=" << arg.endpoints_ << ", rack=" << arg.rack_
+              << "}";
+  }
+};
+
+class UpdateMetadataRequest : public Request {
+public:
+  // v0
+  UpdateMetadataRequest(INT32 controller_id, INT32 controller_epoch,
+                        NULLABLE_ARRAY<MetadataPartitionState> partition_states,
+                        NULLABLE_ARRAY<UpdateMetadataLiveBroker> live_brokers)
+      : Request{RequestType::UpdateMetadata}, controller_id_{controller_id},
+        controller_epoch_{controller_epoch}, partition_states_{partition_states},
+        live_brokers_{live_brokers} {};
+
+  bool operator==(const UpdateMetadataRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && controller_id_ == rhs.controller_id_ &&
+           controller_epoch_ == rhs.controller_epoch_ &&
+           partition_states_ == rhs.partition_states_ && live_brokers_ == rhs.live_brokers_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(controller_id_, dst);
+    written += encoder.encode(controller_epoch_, dst);
+    written += encoder.encode(partition_states_, dst);
+    written += encoder.encode(live_brokers_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{controller_id=" << controller_id_ << ", controller_epoch=" << controller_epoch_
+              << ", partition_states=" << partition_states_ << ", live_brokers=" << live_brokers_
+              << "}";
+  }
+
+private:
+  const INT32 controller_id_;
+  const INT32 controller_epoch_;
+  const NULLABLE_ARRAY<MetadataPartitionState> partition_states_;
+  const NULLABLE_ARRAY<UpdateMetadataLiveBroker> live_brokers_;
+};
+
+// clang-format off
+class UpdateMetadataLiveBrokerEndpointV1Buffer : public CompositeBuffer<UpdateMetadataLiveBrokerEndpoint, Int32Buffer, StringBuffer, Int16Buffer> {};
+class UpdateMetadataLiveBrokerEndpointV1ArrayBuffer : public ArrayBuffer<UpdateMetadataLiveBrokerEndpoint, UpdateMetadataLiveBrokerEndpointV1Buffer> {};
+class UpdateMetadataLiveBrokerEndpointV3Buffer : public CompositeBuffer<UpdateMetadataLiveBrokerEndpoint, Int32Buffer, StringBuffer, StringBuffer, Int16Buffer> {};
+class UpdateMetadataLiveBrokerEndpointV3ArrayBuffer : public ArrayBuffer<UpdateMetadataLiveBrokerEndpoint, UpdateMetadataLiveBrokerEndpointV3Buffer> {};
+
+class UpdateMetadataLiveBrokerV0Buffer : public CompositeBuffer<UpdateMetadataLiveBroker, Int32Buffer, StringBuffer, Int32Buffer> {};
+class UpdateMetadataLiveBrokerV0ArrayBuffer : public ArrayBuffer<UpdateMetadataLiveBroker, UpdateMetadataLiveBrokerV0Buffer> {};
+class UpdateMetadataLiveBrokerV1Buffer : public CompositeBuffer<UpdateMetadataLiveBroker, Int32Buffer, UpdateMetadataLiveBrokerEndpointV1ArrayBuffer> {};
+class UpdateMetadataLiveBrokerV1ArrayBuffer : public ArrayBuffer<UpdateMetadataLiveBroker, UpdateMetadataLiveBrokerV1Buffer> {};
+class UpdateMetadataLiveBrokerV2Buffer : public CompositeBuffer<UpdateMetadataLiveBroker, Int32Buffer, UpdateMetadataLiveBrokerEndpointV1ArrayBuffer, NullableStringBuffer> {};
+class UpdateMetadataLiveBrokerV2ArrayBuffer : public ArrayBuffer<UpdateMetadataLiveBroker, UpdateMetadataLiveBrokerV2Buffer> {};
+class UpdateMetadataLiveBrokerV3Buffer : public CompositeBuffer<UpdateMetadataLiveBroker, Int32Buffer, UpdateMetadataLiveBrokerEndpointV3ArrayBuffer, NullableStringBuffer> {};
+class UpdateMetadataLiveBrokerV3ArrayBuffer : public ArrayBuffer<UpdateMetadataLiveBroker, UpdateMetadataLiveBrokerV3Buffer> {};
+
+class UpdateMetadataRequestV0Buffer : public CompositeBuffer<UpdateMetadataRequest, Int32Buffer, Int32Buffer, MetadataPartitionStateV0ArrayBuffer, UpdateMetadataLiveBrokerV0ArrayBuffer> {};
+class UpdateMetadataRequestV1Buffer : public CompositeBuffer<UpdateMetadataRequest, Int32Buffer, Int32Buffer, MetadataPartitionStateV0ArrayBuffer, UpdateMetadataLiveBrokerV1ArrayBuffer> {};
+class UpdateMetadataRequestV2Buffer : public CompositeBuffer<UpdateMetadataRequest, Int32Buffer, Int32Buffer, MetadataPartitionStateV0ArrayBuffer, UpdateMetadataLiveBrokerV2ArrayBuffer> {};
+class UpdateMetadataRequestV3Buffer : public CompositeBuffer<UpdateMetadataRequest, Int32Buffer, Int32Buffer, MetadataPartitionStateV0ArrayBuffer, UpdateMetadataLiveBrokerV3ArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V0);
+DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V1);
+DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V2);
+DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V3);
+// clang-format on
+
+// === CONTROLLED SHUTDOWN (7) =================================================
+
+// v0 is not documented
+class ControlledShutdownRequest : public Request {
+public:
+  // v1
+  ControlledShutdownRequest(INT32 broker_id)
+      : Request{RequestType::ControlledShutdown}, broker_id_{broker_id} {};
+
+  bool operator==(const ControlledShutdownRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && broker_id_ == rhs.broker_id_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(broker_id_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{broker_id=" << broker_id_ << "}";
+  }
+
+private:
+  const INT32 broker_id_;
+};
+
+// clang-format off
+class ControlledShutdownRequestV1Buffer : public CompositeBuffer<ControlledShutdownRequest, Int32Buffer> {};
+
+DEFINE_REQUEST_PARSER(ControlledShutdownRequest, V1);
+// clang-format on
+
+// === OFFSET COMMIT (8) =======================================================
+
+struct OffsetCommitPartition {
+  const INT32 partition_;
+  const INT64 offset_;
+  const INT64 timestamp_; // only v1
+  const NULLABLE_STRING metadata_;
+
+  // v0 *and* v2
+  OffsetCommitPartition(INT32 partition, INT64 offset, NULLABLE_STRING metadata)
+      : partition_{partition}, offset_{offset}, timestamp_{-1}, metadata_{metadata} {};
+
+  // v1
+  OffsetCommitPartition(INT32 partition, INT64 offset, INT64 timestamp, NULLABLE_STRING metadata)
+      : partition_{partition}, offset_{offset}, timestamp_{timestamp}, metadata_{metadata} {};
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(offset_, dst);
+    if (encoder.apiVersion() == 1) {
+      written += encoder.encode(timestamp_, dst);
+    }
+    written += encoder.encode(metadata_, dst);
+    return written;
+  }
+
+  bool operator==(const OffsetCommitPartition& rhs) const {
+    return partition_ == rhs.partition_ && offset_ == rhs.offset_ &&
+           timestamp_ == rhs.timestamp_ && metadata_ == rhs.metadata_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const OffsetCommitPartition& arg) {
+    return os << "{partition=" << arg.partition_ << ", offset=" << arg.offset_
+              << ", timestamp=" << arg.timestamp_ << ", metadata=" << arg.metadata_ << "}";
+  }
+};
+
+struct OffsetCommitTopic {
+  const STRING topic_;
+  const NULLABLE_ARRAY<OffsetCommitPartition> partitions_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  bool operator==(const OffsetCommitTopic& rhs) const {
+    return topic_ == rhs.topic_ && partitions_ == rhs.partitions_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const OffsetCommitTopic& arg) {
+    return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}";
+  }
+};
+
+class OffsetCommitRequest : public Request {
+public:
+  // v0
+  OffsetCommitRequest(STRING group_id, NULLABLE_ARRAY<OffsetCommitTopic> topics)
+      : OffsetCommitRequest(group_id, -1, "", -1, topics){};
+
+  // v1
+  OffsetCommitRequest(STRING group_id, INT32 group_generation_id, STRING member_id,
+                      NULLABLE_ARRAY<OffsetCommitTopic> topics)
+      : OffsetCommitRequest(group_id, group_generation_id, member_id, -1, topics){};
+
+  // v2 .. v3
+  OffsetCommitRequest(STRING group_id, INT32 group_generation_id, STRING member_id,
+                      INT64 retention_time, NULLABLE_ARRAY<OffsetCommitTopic> topics)
+      : Request{RequestType::OffsetCommit}, group_id_{group_id},
+        group_generation_id_{group_generation_id}, member_id_{member_id},
+        retention_time_{retention_time}, topics_{topics} {};
+
+  bool operator==(const OffsetCommitRequest& rhs) const {
+    return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ &&
+           group_generation_id_ == rhs.group_generation_id_ && member_id_ == rhs.member_id_ &&
+           retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(group_id_, dst);
+    if (encoder.apiVersion() >= 1) {
+      written += encoder.encode(group_generation_id_, dst);
+      written += encoder.encode(member_id_, dst);
+    }
+    if (encoder.apiVersion() >= 2) {
+      written += encoder.encode(retention_time_, dst);
+    }
+    written += encoder.encode(topics_, dst);
+    return written;
+  }
+
+  std::ostream& printDetails(std::ostream& os) const override {
+    return os << "{group_id=" << group_id_ << ", group_generation_id=" << group_generation_id_
+              << ", member_id=" << member_id_ << ", retention_time=" << retention_time_
+              << ", topics=" << topics_ << "}";
+  }
+
+private:
+  const STRING group_id_;
+  const INT32 group_generation_id_; // since v1
+  const STRING member_id_;          // since v1
+  const INT64 retention_time_;      // since v2
+  const NULLABLE_ARRAY<OffsetCommitTopic> topics_;
+};
+
+// clang-format off
+class OffsetCommitPartitionV0Buffer : public CompositeBuffer<OffsetCommitPartition, Int32Buffer, Int64Buffer, NullableStringBuffer> {};
+class OffsetCommitPartitionV0ArrayBuffer : public ArrayBuffer<OffsetCommitPartition, OffsetCommitPartitionV0Buffer> {};
+class OffsetCommitTopicV0Buffer : public CompositeBuffer<OffsetCommitTopic, StringBuffer, OffsetCommitPartitionV0ArrayBuffer> {};
+class OffsetCommitTopicV0ArrayBuffer : public ArrayBuffer<OffsetCommitTopic, OffsetCommitTopicV0Buffer> {};
+
+class OffsetCommitPartitionV1Buffer : public CompositeBuffer<OffsetCommitPartition, Int32Buffer, Int64Buffer, Int64Buffer, NullableStringBuffer> {};
+class OffsetCommitPartitionV1ArrayBuffer : public ArrayBuffer<OffsetCommitPartition, OffsetCommitPartitionV1Buffer> {};
+class OffsetCommitTopicV1Buffer : public CompositeBuffer<OffsetCommitTopic, StringBuffer, OffsetCommitPartitionV1ArrayBuffer> {};
+class OffsetCommitTopicV1ArrayBuffer : public ArrayBuffer<OffsetCommitTopic, OffsetCommitTopicV1Buffer> {};
+
+class OffsetCommitTopicV2ArrayBuffer : public OffsetCommitTopicV0ArrayBuffer {}; // v2 partition format is the same as v0
+
+class OffsetCommitRequestV0Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, OffsetCommitTopicV0ArrayBuffer> {};
+class OffsetCommitRequestV1Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, Int32Buffer, StringBuffer, OffsetCommitTopicV1ArrayBuffer> {};
+class OffsetCommitRequestV2Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, Int32Buffer, StringBuffer, Int64Buffer, OffsetCommitTopicV2ArrayBuffer> {};
+
+DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0);
+DEFINE_REQUEST_PARSER(OffsetCommitRequest, V1);
+DEFINE_REQUEST_PARSER(OffsetCommitRequest, V2);
+// clang-format on
+
+// === OFFSET FETCH (9) ========================================================
+
+struct OffsetFetchTopic {
+  const STRING topic_;
+  const NULLABLE_ARRAY<INT32> partitions_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  bool operator==(const OffsetFetchTopic& rhs) const {
+    return topic_ == rhs.topic_ && partitions_ == rhs.partitions_;
+  };
+
+  friend std::ostream& operator<<(std::ostream& os, const OffsetFetchTopic& arg) {
+    return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}";
+  }
+};
+
v3 + OffsetFetchRequest(STRING group_id, NULLABLE_ARRAY topics) + : Request{RequestType::OffsetFetch}, group_id_{group_id}, topics_{topics} {}; + + bool operator==(const OffsetFetchRequest& rhs) const { + return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && + topics_ == rhs.topics_; + }; + +protected: + size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { + size_t written{0}; + written += encoder.encode(group_id_, dst); + written += encoder.encode(topics_, dst); + return written; + } + + std::ostream& printDetails(std::ostream& os) const override { + return os << "{group_id=" << group_id_ << ", topics=" << topics_ << "}"; + } + +private: + const STRING group_id_; + const NULLABLE_ARRAY topics_; +}; + +// clang-format off +class OffsetFetchPartitionV0ArrayBuffer : public ArrayBuffer {}; +class OffsetFetchTopicV0Buffer : public CompositeBuffer {}; +class OffsetFetchTopicV0ArrayBuffer : public ArrayBuffer {}; + +class OffsetFetchRequestV0Buffer : public CompositeBuffer {}; + +DEFINE_REQUEST_PARSER(OffsetFetchRequest, V0); +// clang-format on + +// === API VERSIONS (18) ======================================================= + +class ApiVersionsRequest : public Request { +public: + // v0 .. v1 + ApiVersionsRequest() : Request{RequestType::ApiVersions} {}; + + bool operator==(const ApiVersionsRequest& rhs) const { + return request_header_ == rhs.request_header_; + }; + +protected: + size_t encodeDetails(Buffer::Instance&, EncodingContext&) const override { return 0; } + + std::ostream& printDetails(std::ostream& os) const override { return os << "{}"; } +}; + +// clang-format off +class ApiVersionsRequestV0Buffer : public NullBuffer {}; + +DEFINE_REQUEST_PARSER(ApiVersionsRequest, V0); +// clang-format on + +// === UNKNOWN REQUEST ========================================================= + +class UnknownRequest : public Request { +public: + UnknownRequest(const RequestHeader& request_header) : Request{request_header} {}; + +protected: + // this isn't the prettiest, as we have thrown away the data + // XXX(adam.kotwasinski) discuss capturing the data as-is, and simply putting it back + // this would add ability to forward unknown types of requests in cluster-proxy + size_t encodeDetails(Buffer::Instance&, EncodingContext&) const override { + throw EnvoyException("cannot serialize unknown request"); + } + + std::ostream& printDetails(std::ostream& out) const override { + return out << "{unknown request}"; + } +}; + +// ignores data until the end of request (contained in context_) +class SentinelConsumer : public Parser { +public: + SentinelConsumer(RequestContextSharedPtr context) : context_{context} {}; + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestContextSharedPtr context_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_types.h b/source/extensions/filters/network/kafka/kafka_types.h new file mode 100644 index 0000000000000..f5c188e2dff59 --- /dev/null +++ b/source/extensions/filters/network/kafka/kafka_types.h @@ -0,0 +1,31 @@ +#pragma once + +#include +#include + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +typedef int8_t INT8; +typedef int16_t INT16; +typedef int32_t INT32; +typedef int64_t INT64; 
+typedef uint32_t UINT32;
+typedef bool BOOLEAN;
+
+typedef std::string STRING;
+typedef absl::optional<STRING> NULLABLE_STRING;
+
+typedef std::vector<char> BYTES;
+typedef absl::optional<BYTES> NULLABLE_BYTES;
+
+template <typename T> using NULLABLE_ARRAY = absl::optional<std::vector<T>>;
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h
new file mode 100644
index 0000000000000..7d53597745508
--- /dev/null
+++ b/source/extensions/filters/network/kafka/message.h
@@ -0,0 +1,31 @@
+#pragma once
+
+#include <iostream>
+#include <memory>
+
+#include "envoy/common/pure.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+/**
+ * Abstract message
+ */
+class Message {
+public:
+  virtual ~Message() = default;
+
+  friend std::ostream& operator<<(std::ostream& out, const Message& arg) { return arg.print(out); }
+
+protected:
+  virtual std::ostream& print(std::ostream& os) const PURE;
+};
+
+typedef std::shared_ptr<Message> MessageSharedPtr;
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/source/extensions/filters/network/kafka/parser.h b/source/extensions/filters/network/kafka/parser.h
new file mode 100644
index 0000000000000..d9b11ae2b2ec2
--- /dev/null
+++ b/source/extensions/filters/network/kafka/parser.h
@@ -0,0 +1,48 @@
+#pragma once
+
+#include <memory>
+
+#include "common/common/logger.h"
+
+#include "extensions/filters/network/kafka/kafka_types.h"
+#include "extensions/filters/network/kafka/message.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+// === PARSER ==================================================================
+
+class ParseResponse;
+
+class Parser : public Logger::Loggable<Logger::Id::kafka> {
+public:
+  virtual ~Parser() = default;
+
+  virtual ParseResponse parse(const char*& buffer, uint64_t& remaining) PURE;
+};
+
+typedef std::shared_ptr<Parser> ParserSharedPtr;
+
+class ParseResponse {
+public:
+  static ParseResponse stillWaiting() { return {nullptr, nullptr}; }
+  static ParseResponse nextParser(ParserSharedPtr next_parser) { return {next_parser, nullptr}; };
+  static ParseResponse parsedMessage(MessageSharedPtr message) { return {nullptr, message}; };
+
+  bool hasData() const { return (next_parser_ != nullptr) || (message_ != nullptr); }
+
+private:
+  ParseResponse(ParserSharedPtr parser, MessageSharedPtr message)
+      : next_parser_{parser}, message_{message} {};
+
+public:
+  ParserSharedPtr next_parser_;
+  MessageSharedPtr message_;
+};
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc
new file mode 100644
index 0000000000000..9ec598e30d0a7
--- /dev/null
+++ b/source/extensions/filters/network/kafka/request_codec.cc
@@ -0,0 +1,62 @@
+#include "extensions/filters/network/kafka/request_codec.h"
+
+#include "common/buffer/buffer_impl.h"
+
+#include "extensions/filters/network/kafka/kafka_protocol.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+// === DECODER =================================================================
+
+void RequestDecoder::onData(Buffer::Instance& data) {
+  uint64_t num_slices = data.getRawSlices(nullptr, 0);
+  Buffer::RawSlice slices[num_slices];
+  data.getRawSlices(slices, num_slices);
+  for (const
Buffer::RawSlice& slice : slices) { + doParse(current_parser_, slice); + } +} + +void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& slice) { + const char* buffer = reinterpret_cast(slice.mem_); + uint64_t remaining = slice.len_; + while (remaining) { + ParseResponse result = parser->parse(buffer, remaining); + // this loop guarantees that parsers consuming 0 bytes also get processed + while (result.hasData()) { + if (!result.next_parser_) { + + // next parser is not present, so we have finished parsing a message + MessageSharedPtr message = result.message_; + ENVOY_LOG(trace, "parsed message: {}", *message); + for (auto& callback : callbacks_) { + callback->onMessage(result.message_); + } + + // we finished parsing this request, start anew + parser = std::make_shared(parser_resolver_); + } else { + parser = result.next_parser_; + } + result = parser->parse(buffer, remaining); + } + } +} + +// === ENCODER ================================================================= + +void RequestEncoder::encode(const Request& message) { + EncodingContext encoder{message.apiVersion()}; + Buffer::OwnedImpl data_buffer; + INT32 data_len = encoder.encode(message, data_buffer); // encode data computing data length + encoder.encode(data_len, output_); // encode data length into result + output_.add(data_buffer); // copy data into result +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h new file mode 100644 index 0000000000000..ebd797ea1a75a --- /dev/null +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -0,0 +1,69 @@ +#pragma once + +#include "envoy/buffer/buffer.h" +#include "envoy/common/pure.h" + +#include "extensions/filters/network/kafka/codec.h" +#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/parser.h" +#include "extensions/filters/network/kafka/serialization.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +// === DECODER ================================================================= + +/** + * Invoked when request is successfully decoded + */ +class RequestCallback { +public: + virtual ~RequestCallback() = default; + + virtual void onMessage(MessageSharedPtr) PURE; +}; + +typedef std::shared_ptr RequestCallbackSharedPtr; + +/** + * Decoder that decodes Kafka requests + * When a request is decoded, the callbacks are notified, in order + * + * This decoder uses chain of parsers to parse fragments of a request + * Each parser along the line returns the fully parsed message or the next parser + */ +class RequestDecoder : public MessageDecoder, public Logger::Loggable { +public: + RequestDecoder(const RequestParserResolver parserResolver, + const std::vector callbacks) + : parser_resolver_{parserResolver}, callbacks_{callbacks}, + current_parser_{new RequestStartParser(parser_resolver_)} {}; + + void onData(Buffer::Instance& data); + +private: + void doParse(ParserSharedPtr& parser, const Buffer::RawSlice& slice); + + const RequestParserResolver parser_resolver_; + const std::vector callbacks_; + + ParserSharedPtr current_parser_; +}; + +// === ENCODER ================================================================= + +class RequestEncoder : public MessageEncoder { +public: + RequestEncoder(Buffer::Instance& output) : output_(output) {} + void encode(const Request& message) 
override; + +private: + Buffer::Instance& output_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h new file mode 100644 index 0000000000000..afc748cb7bbe0 --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization.h @@ -0,0 +1,772 @@ +#pragma once + +#include +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" +#include "envoy/common/pure.h" + +#include "common/common/byte_order.h" +#include "common/common/fmt.h" + +#include "extensions/filters/network/kafka/kafka_types.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +// ============================================================================= +// === DESERIALIZERS =========================================================== +// ============================================================================= + +/** + * The general idea of Buffer is that it can be feed-ed data until it is ready + * When true == ready(), it is safe to call get() + * Further feed()-ing should have no effect on a buffer + * (should return 0 and not move buffer/remaining) + */ + +// === ABSTRACT DESERIALIZER =================================================== + +template class Deserializer { +public: + virtual ~Deserializer() = default; + + virtual size_t feed(const char*& buffer, uint64_t& remaining) PURE; + virtual bool ready() const PURE; + virtual T get() const PURE; +}; + +// === INT BUFFERS ============================================================= + +/** + * The values are encoded in network byte order (big-endian). + */ +template class IntBuffer : public Deserializer { +public: + IntBuffer() : written_{0}, ready_(false){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t available = std::min(sizeof(buf_) - written_, remaining); + memcpy(buf_ + written_, buffer, available); + written_ += available; + + if (written_ == sizeof(buf_)) { + ready_ = true; + } + + buffer += available; + remaining -= available; + + return available; + } + + bool ready() const { return ready_; } + +protected: + char buf_[sizeof(T) / sizeof(char)]; + size_t written_; + bool ready_; +}; + +class Int8Buffer : public IntBuffer { +public: + INT8 get() const { + INT8 result; + memcpy(&result, buf_, sizeof(result)); + return result; + } +}; + +class Int16Buffer : public IntBuffer { +public: + INT16 get() const { + INT16 result; + memcpy(&result, buf_, sizeof(result)); + return be16toh(result); + } +}; + +class Int32Buffer : public IntBuffer { +public: + INT32 get() const { + INT32 result; + memcpy(&result, buf_, sizeof(result)); + return be32toh(result); + } +}; + +class UInt32Buffer : public IntBuffer { +public: + UINT32 get() const { + UINT32 result; + memcpy(&result, buf_, sizeof(result)); + return be32toh(result); + } +}; + +class Int64Buffer : public IntBuffer { +public: + INT64 get() const { + INT64 result; + memcpy(&result, buf_, sizeof(result)); + return be64toh(result); + } +}; + +// === BOOL BUFFER ============================================================= + +/** + * Represents a boolean value in a byte. + * Values 0 and 1 are used to represent false and true respectively. + * When reading a boolean value, any non-zero value is considered true. 
+ */ +class BoolBuffer : public Deserializer { +public: + BoolBuffer(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { return buffer_.feed(buffer, remaining); } + + bool ready() const { return buffer_.ready(); } + + BOOLEAN get() const { return 0 != buffer_.get(); } + +private: + Int8Buffer buffer_; +}; + +// === STRING BUFFER =========================================================== + +/** + * Represents a sequence of characters. + * First the length N is given as an INT16. + * Then N bytes follow which are the UTF-8 encoding of the character sequence. + * Length must not be negative. + */ +class StringBuffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } else { + throw EnvoyException(fmt::format("invalid STRING length: {}", required_)); + } + length_consumed_ = true; + } + + const size_t data_consumed = std::min(required_, remaining); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, buffer, data_consumed); + required_ -= data_consumed; + + buffer += data_consumed; + remaining -= data_consumed; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const { return ready_; } + + STRING get() const { return std::string(data_buf_.begin(), data_buf_.end()); } + +private: + Int16Buffer length_buf_; + bool length_consumed_{false}; + + INT16 required_; + std::vector data_buf_; + + bool ready_{false}; +}; + +/** + * Represents a sequence of characters or null. + * For non-null strings, first the length N is given as an INT16. + * Then N bytes follow which are the UTF-8 encoding of the character sequence. + * A null value is encoded with length of -1 and there are no following bytes. + */ +class NullableStringBuffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } + if (required_ == NULL_STRING_LENGTH) { + ready_ = true; + } + if (required_ < NULL_STRING_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_STRING length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + const size_t data_consumed = std::min(required_, remaining); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, buffer, data_consumed); + required_ -= data_consumed; + + buffer += data_consumed; + remaining -= data_consumed; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const { return ready_; } + + NULLABLE_STRING get() const { + return required_ >= 0 ? 
absl::make_optional(std::string(data_buf_.begin(), data_buf_.end())) + : absl::nullopt; + } + +private: + constexpr static INT16 NULL_STRING_LENGTH{-1}; + + Int16Buffer length_buf_; + bool length_consumed_{false}; + + INT16 required_; + std::vector data_buf_; + + bool ready_{false}; +}; + +// === BYTES BUFFERS =========================================================== + +/** + * Represents a raw sequence of bytes or null. + * For non-null values, first the length N is given as an INT32. Then N bytes follow. + * A null value is encoded with length of -1 and there are no following bytes. + */ + +/** + * This buffer ignores the data fed, the only result is the number of bytes ignored + */ +class NullableBytesIgnoringBuffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_max_ = length_buf_.get(); + required_ = length_buf_.get(); + + if (required_ == NULL_BYTES_LENGTH) { + ready_ = true; + } + if (required_ < NULL_BYTES_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + const size_t data_consumed = std::min(required_, remaining); + required_ -= data_consumed; + + buffer += data_consumed; + remaining -= data_consumed; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const { return ready_; } + + /** + * Returns length of ignored array, or -1 if that was null + */ + INT32 get() const { return required_max_; } + +private: + constexpr static INT32 NULL_BYTES_LENGTH{-1}; + + Int32Buffer length_buf_; + bool length_consumed_{false}; + INT32 required_max_; + INT32 required_; + bool ready_{false}; +}; + +/** + * This buffer captures the data fed + */ +class NullableBytesCapturingBuffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } + if (required_ == NULL_BYTES_LENGTH) { + ready_ = true; + } + if (required_ < NULL_BYTES_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + const size_t data_consumed = std::min(required_, remaining); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, buffer, data_consumed); + required_ -= data_consumed; + + buffer += data_consumed; + remaining -= data_consumed; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const { return ready_; } + + NULLABLE_BYTES get() const { + if (NULL_BYTES_LENGTH == required_) { + return absl::nullopt; + } else { + return {data_buf_}; + } + } + +private: + constexpr static INT32 NULL_BYTES_LENGTH{-1}; + + Int32Buffer length_buf_; + bool length_consumed_{false}; + INT32 required_; + + std::vector data_buf_; + bool ready_{false}; +}; + +// === COMPOSITE BUFFER 
======================================================== + +/** + * Composes several buffers into one. + * The returned value is constructed via { buffer1.get(), buffer2.get() ... } + */ +template class CompositeBuffer; + +template class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer1_.ready(); } + RT get() const { return {buffer1_.get()}; } + +protected: + T1 buffer1_; +}; + +template +class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + consumed += buffer2_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer2_.ready(); } + RT get() const { return {buffer1_.get(), buffer2_.get()}; } + +protected: + T1 buffer1_; + T2 buffer2_; +}; + +template +class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + consumed += buffer2_.feed(buffer, remaining); + consumed += buffer3_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer3_.ready(); } + RT get() const { return {buffer1_.get(), buffer2_.get(), buffer3_.get()}; } + +protected: + T1 buffer1_; + T2 buffer2_; + T3 buffer3_; +}; + +template +class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + consumed += buffer2_.feed(buffer, remaining); + consumed += buffer3_.feed(buffer, remaining); + consumed += buffer4_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer4_.ready(); } + RT get() const { return {buffer1_.get(), buffer2_.get(), buffer3_.get(), buffer4_.get()}; } + +protected: + T1 buffer1_; + T2 buffer2_; + T3 buffer3_; + T4 buffer4_; +}; + +template +class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + consumed += buffer2_.feed(buffer, remaining); + consumed += buffer3_.feed(buffer, remaining); + consumed += buffer4_.feed(buffer, remaining); + consumed += buffer5_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer5_.ready(); } + RT get() const { + return {buffer1_.get(), buffer2_.get(), buffer3_.get(), buffer4_.get(), buffer5_.get()}; + } + +protected: + T1 buffer1_; + T2 buffer2_; + T3 buffer3_; + T4 buffer4_; + T5 buffer5_; +}; + +template +class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + consumed += buffer2_.feed(buffer, remaining); + consumed += buffer3_.feed(buffer, remaining); + consumed += buffer4_.feed(buffer, remaining); + consumed += buffer5_.feed(buffer, remaining); + consumed += buffer6_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer6_.ready(); } + RT get() const { + return {buffer1_.get(), buffer2_.get(), buffer3_.get(), + buffer4_.get(), buffer5_.get(), buffer6_.get()}; + } + +protected: + 
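// Constituent deserializers; 'feed' drives each of them in declaration order,
+  // and a deserializer that is already ready consumes no further bytes.
+ 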
T1 buffer1_; + T2 buffer2_; + T3 buffer3_; + T4 buffer4_; + T5 buffer5_; + T6 buffer6_; +}; + +template +class CompositeBuffer : public Deserializer { +public: + CompositeBuffer(){}; + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += buffer1_.feed(buffer, remaining); + consumed += buffer2_.feed(buffer, remaining); + consumed += buffer3_.feed(buffer, remaining); + consumed += buffer4_.feed(buffer, remaining); + consumed += buffer5_.feed(buffer, remaining); + consumed += buffer6_.feed(buffer, remaining); + consumed += buffer7_.feed(buffer, remaining); + consumed += buffer8_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return buffer8_.ready(); } + RT get() const { + return {buffer1_.get(), buffer2_.get(), buffer3_.get(), buffer4_.get(), + buffer5_.get(), buffer6_.get(), buffer7_.get(), buffer8_.get()}; + } + +protected: + T1 buffer1_; + T2 buffer2_; + T3 buffer3_; + T4 buffer4_; + T5 buffer5_; + T6 buffer6_; + T7 buffer7_; + T8 buffer8_; +}; + +// === ARRAY BUFFER ============================================================ + +/** + * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. + * STRING) or a structure. First, the length N is given as an INT32. Then N instances of type T + * follow. A null array is represented with a length of -1. + */ + +template class ArrayBuffer : public Deserializer> { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + children_ = std::vector(required_); + } + if (required_ == NULL_ARRAY_LENGTH) { + ready_ = true; + } + if (required_ < NULL_ARRAY_LENGTH) { + throw EnvoyException(fmt::format("invalid array length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + size_t child_consumed{0}; + for (CT& child : children_) { + child_consumed += child.feed(buffer, remaining); + } + + bool children_ready_ = true; + for (CT& child : children_) { + children_ready_ &= child.ready(); + } + ready_ = children_ready_; + + return length_consumed + child_consumed; + } + + bool ready() const { return ready_; } + + NULLABLE_ARRAY get() const { + if (NULL_ARRAY_LENGTH != required_) { + std::vector result{}; + result.reserve(children_.size()); + for (const CT& child : children_) { + const RT child_result = child.get(); + result.push_back(child_result); + } + return {result}; + } else { + return absl::nullopt; + } + } + +private: + constexpr static INT32 NULL_ARRAY_LENGTH{-1}; + + Int32Buffer length_buf_; + bool length_consumed_{false}; + INT32 required_; + std::vector children_; + bool children_setup_{false}; + bool ready_{false}; +}; + +// === NULL BUFFER ============================================================= + +/** + * Consumes no bytes, used as placeholder + */ +template class NullBuffer : public Deserializer { +public: + size_t feed(const char*&, uint64_t&) { return 0; } + + bool ready() const { return true; } + + RT get() const { return {}; } +}; + +// ============================================================================= +// === ENCODER HELPER ========================================================== +// ============================================================================= + +/** + * Encodes 
the provided argument in Kafka format.
+ * For primitive types, this is done explicitly, per the protocol specification.
+ * For composite types, this is done by calling 'encode' on the argument itself.
+ */
+
+class EncodingContext {
+public:
+  EncodingContext(INT16 api_version) : api_version_{api_version} {};
+
+  template <typename T> size_t encode(const T& arg, Buffer::Instance& dst);
+
+  template <typename T> size_t encode(const NULLABLE_ARRAY<T>& arg, Buffer::Instance& dst);
+
+  INT16 apiVersion() const { return api_version_; }
+
+private:
+  const INT16 api_version_;
+};
+
+template <typename T> inline size_t EncodingContext::encode(const T& arg, Buffer::Instance& dst) {
+  return arg.encode(dst, *this);
+}
+
+template <> inline size_t EncodingContext::encode(const INT8& arg, Buffer::Instance& dst) {
+  dst.add(&arg, sizeof(INT8));
+  return sizeof(INT8);
+}
+
+#define ENCODE_NUMERIC_TYPE(TYPE, CONVERTER)                                                       \
+  template <> inline size_t EncodingContext::encode(const TYPE& arg, Buffer::Instance& dst) {      \
+    TYPE val = CONVERTER(arg);                                                                     \
+    dst.add(&val, sizeof(TYPE));                                                                   \
+    return sizeof(TYPE);                                                                           \
+  }
+
+ENCODE_NUMERIC_TYPE(INT16, htobe16);
+ENCODE_NUMERIC_TYPE(INT32, htobe32);
+ENCODE_NUMERIC_TYPE(UINT32, htobe32);
+ENCODE_NUMERIC_TYPE(INT64, htobe64);
+
+template <> inline size_t EncodingContext::encode(const BOOLEAN& arg, Buffer::Instance& dst) {
+  INT8 val = arg;
+  dst.add(&val, sizeof(INT8));
+  return sizeof(INT8);
+}
+
+template <> inline size_t EncodingContext::encode(const STRING& arg, Buffer::Instance& dst) {
+  INT16 string_length = arg.length();
+  size_t header_length = encode(string_length, dst);
+  dst.add(arg.c_str(), string_length);
+  return header_length + string_length;
+}
+
+template <>
+inline size_t EncodingContext::encode(const NULLABLE_STRING& arg, Buffer::Instance& dst) {
+  if (arg.has_value()) {
+    return encode(*arg, dst);
+  } else {
+    INT16 len = -1;
+    return encode(len, dst);
+  }
+}
+
+template <> inline size_t EncodingContext::encode(const BYTES& arg, Buffer::Instance& dst) {
+  INT32 data_length = arg.size();
+  size_t header_length = encode(data_length, dst);
+  dst.add(arg.data(), arg.size());
+  return header_length + data_length;
+}
+
+template <>
+inline size_t EncodingContext::encode(const NULLABLE_BYTES& arg, Buffer::Instance& dst) {
+  if (arg.has_value()) {
+    return encode(*arg, dst);
+  } else {
+    INT32 len = -1;
+    return encode(len, dst);
+  }
+}
+
+template <typename T>
+size_t EncodingContext::encode(const NULLABLE_ARRAY<T>& arg, Buffer::Instance& dst) {
+  if (arg.has_value()) {
+    INT32 len = arg->size();
+    size_t header_length = encode(len, dst);
+    size_t written{0};
+    for (const T& el : *arg) {
+      written += encode(el, dst);
+    }
+    return header_length + written;
+  } else {
+    INT32 len = -1;
+    return encode(len, dst);
+  }
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/source/extensions/filters/network/well_known_names.h b/source/extensions/filters/network/well_known_names.h
index 6a68c32223c41..466178b37ed8f 100644
--- a/source/extensions/filters/network/well_known_names.h
+++ b/source/extensions/filters/network/well_known_names.h
@@ -30,6 +30,8 @@ class NetworkFilterNameValues {
   const std::string TcpProxy = "envoy.tcp_proxy";
   // Authorization filter
   const std::string ExtAuthorization = "envoy.ext_authz";
+  // Kafka filter
+  const std::string Kafka = "envoy.filters.network.kafka";
   // Thrift proxy filter
   const std::string ThriftProxy = "envoy.filters.network.thrift_proxy";
   // Role based access control filter
diff --git a/test/extensions/filters/network/kafka/BUILD
b/test/extensions/filters/network/kafka/BUILD new file mode 100644 index 0000000000000..ab85bf20aaced --- /dev/null +++ b/test/extensions/filters/network/kafka/BUILD @@ -0,0 +1,42 @@ +licenses(["notice"]) # Apache 2 + +load( + "//bazel:envoy_build_system.bzl", + "envoy_package", +) +load( + "//test/extensions:extensions_build_system.bzl", + "envoy_extension_cc_test", +) + +envoy_package() + +envoy_extension_cc_test( + name = "serialization_test", + srcs = ["serialization_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:serialization_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "kafka_request_test", + srcs = ["kafka_request_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:kafka_request_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "request_codec_test", + srcs = ["request_codec_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc new file mode 100644 index 0000000000000..c57361c5edbce --- /dev/null +++ b/test/extensions/filters/network/kafka/kafka_request_test.cc @@ -0,0 +1,173 @@ +#include "extensions/filters/network/kafka/kafka_request.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::Return; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +TEST(RequestParserResolver, ShouldReturnSentinelIfRequestTypeIsNotRegistered) { + // given + RequestParserResolver testee{{}}; + RequestContextSharedPtr context{new RequestContext{}}; + + // when + ParserSharedPtr result = testee.createParser(0, 1, context); // api_key = 0 was not registered + + // then + ASSERT_NE(result, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); +} + +TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) { + // given + GeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { + return std::make_shared(arg); + }; + RequestParserResolver testee{{{0, {0, 1}, generator}}}; + RequestContextSharedPtr context{new RequestContext{}}; + + // when + ParserSharedPtr result = + testee.createParser(0, 2, context); // api_version = 2 was not registered (0 & 1 were) + + // then + ASSERT_NE(result, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); +} + +TEST(RequestParserResolver, ShouldInvokeGeneratorFunctionOnMatch) { + // given + GeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { + return std::make_shared(arg); + }; + RequestParserResolver testee{{{0, {0, 1, 2, 3}, generator}}}; + RequestContextSharedPtr context{new RequestContext{}}; + + // when + ParserSharedPtr result = testee.createParser(0, 3, context); + + // then + ASSERT_NE(result, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); +} + +class BufferBasedTest : public testing::Test { +public: + Buffer::OwnedImpl& buffer() { return buffer_; } + + const char* getBytes() { + uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); + Buffer::RawSlice slices[num_slices]; + buffer_.getRawSlices(slices, num_slices); + return 
reinterpret_cast((slices[0]).mem_); + } + +private: + Buffer::OwnedImpl buffer_; + EncodingContext encoder_{-1}; +}; + +TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { + // given + RequestStartParser testee{RequestParserResolver{{}}}; + + INT32 request_len = 1234; + encoder_.encode(request_len, buffer()); + + const char* bytes = getBytes(); + uint64_t remaining = 1024; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); +} + +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver() : RequestParserResolver{{}} {}; + MOCK_CONST_METHOD3(createParser, ParserSharedPtr(INT16, INT16, RequestContextSharedPtr)); +}; + +TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) { + // given + const MockRequestParserResolver parser_resolver; + const ParserSharedPtr parser{new ApiVersionsRequestV0Parser{nullptr}}; + EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser)); + + const INT32 request_len = 1000; + RequestContextSharedPtr context{new RequestContext()}; + context->remaining_request_size_ = request_len; + RequestHeaderParser testee{parser_resolver, context}; + + const INT16 api_key{1}; + const INT16 api_version{2}; + const INT32 correlation_id{10}; + const NULLABLE_STRING client_id{"aaa"}; + size_t written = 0; + written += encoder_.encode(api_key, buffer()); + written += encoder_.encode(api_version, buffer()); + written += encoder_.encode(correlation_id, buffer()); + written += encoder_.encode(client_id, buffer()); + + const char* bytes = getBytes(); + uint64_t remaining = 100000; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, parser); + ASSERT_EQ(result.message_, nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len - written); + ASSERT_EQ(remaining, orig_remaining - written); + + const RequestHeader expected_header{api_key, api_version, correlation_id, client_id}; + ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); +} + +TEST_F(BufferBasedTest, SentinelConsumerShouldConsumeDataUntilEndOfRequest) { + // given + const INT32 request_len = 1000; + RequestContextSharedPtr context{new RequestContext()}; + context->remaining_request_size_ = request_len; + SentinelConsumer testee{context}; + + const BYTES garbage(request_len * 2); + encoder_.encode(garbage, buffer()); + + const char* bytes = getBytes(); + uint64_t remaining = request_len * 2; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); + ASSERT_EQ(remaining, orig_remaining - request_len); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc new file mode 100644 index 
0000000000000..3e834ca76751c --- /dev/null +++ b/test/extensions/filters/network/kafka/request_codec_test.cc @@ -0,0 +1,538 @@ +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +class RequestDecoderTest : public testing::Test { +public: + Buffer::OwnedImpl buffer_; + + template std::shared_ptr serializeAndDeserialize(T request); +}; + +class MockMessageListener : public RequestCallback { +public: + MOCK_METHOD1(onMessage, void(MessageSharedPtr)); +}; + +template std::shared_ptr RequestDecoderTest::serializeAndDeserialize(T request) { + RequestEncoder serializer{buffer_}; + serializer.encode(request); + + std::shared_ptr mock_listener = std::make_shared(); + RequestDecoder testee{RequestParserResolver::KAFKA_0_11, {mock_listener}}; + + MessageSharedPtr receivedMessage; + EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&receivedMessage)); + + testee.onData(buffer_); + + return std::dynamic_pointer_cast(receivedMessage); +}; + +// === PRODUCE (0) ============================================================= + +TEST_F(RequestDecoderTest, shouldParseProduceRequestV0toV2) { + // given + NULLABLE_ARRAY topics{ + {{"t1", {{{0, NULLABLE_BYTES(100)}, {1, NULLABLE_BYTES(200)}}}}, + {"t2", {{{0, NULLABLE_BYTES(300)}}}}}}; + FatProduceRequest request{10, 20, topics}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseProduceRequestV3) { + // given + NULLABLE_ARRAY topics{ + {{"t1", {{{0, NULLABLE_BYTES(100)}, {1, NULLABLE_BYTES(200)}}}}, + {"t2", {{{0, NULLABLE_BYTES(300)}}}}}}; + // transaction_id in V3 + FatProduceRequest request{"txid", 10, 20, topics}; + request.apiVersion() = 3; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === FETCH (1) =============================================================== + +TEST_F(RequestDecoderTest, shouldParseFetchRequestV0toV2) { + // given + FetchRequest request{1, + 1000, + 10, + {{ + {"topic1", {{{10, 20, 2000}}}}, + {"topic1", {{{11, 21, 2001}, {12, 22, 2002}}}}, + {"topic1", {{{13, 23, 2003}}}}, + }}}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseFetchRequestV3) { + // given + FetchRequest request{1, + 1000, + 10, + 20, // max_bytes in V3 + {{ + {"topic1", {{{10, 20, 2000}}}}, + {"topic1", {{{11, 21, 2001}, {12, 22, 2002}}}}, + {"topic1", {{{13, 23, 2003}}}}, + }}}; + request.apiVersion() = 3; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseFetchRequestV4) { + // given + FetchRequest request{1, + 1000, + 10, + 20, + 2, // isolation level in V4 + {{ + {"topic1", {{{10, 20, 2000}}}}, + {"topic1", 
{{{11, 21, 2001}, {12, 22, 2002}}}}, + {"topic1", {{{13, 23, 2003}}}}, + }}}; + request.apiVersion() = 4; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseFetchRequestV5) { + // given + FetchRequest request{1, + 1000, + 10, + 20, + 2, + {{ + // log_start_offset_ in partition data in V5 + {"topic1", {{{10, 20, 1000, 2000}}}}, + {"topic1", {{{11, 21, 1001, 2001}, {12, 22, 1002, 2002}}}}, + {"topic1", {{{13, 23, 1003, 2003}}}}, + }}}; + request.apiVersion() = 5; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === LIST OFFSETS (2) ======================================================== + +TEST_F(RequestDecoderTest, shouldParseListOffsetsRequestV0) { + // given + ListOffsetsRequest request{10, + {{ + // partition contains max_num_offsets in v0 only + {"topic1", {{{1, 1000, 10}, {2, 2000, 20}}}}, + {"topic2", {{{3, 3000, 30}}}}, + }}}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseListOffsetsRequestV1) { + // given + ListOffsetsRequest request{10, + {{ + // max_num_offsets removed in v1 + {"topic1", {{{1, 1000}, {2, 2000}}}}, + {"topic2", {{{3, 3000}}}}, + }}}; + request.apiVersion() = 1; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseListOffsetsRequestV2) { + // given + ListOffsetsRequest request{10, + 2, // isolation level in v2 + {{ + {"topic1", {{{1, 1000}, {2, 2000}}}}, + {"topic2", {{{3, 3000}}}}, + }}}; + request.apiVersion() = 2; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === METADATA (3) ============================================================ + +TEST_F(RequestDecoderTest, shouldParseMetadataRequestV0toV3) { + // given + MetadataRequest request{{{"t1", "t2", "t3"}}}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseMetadataRequestV4) { + // given + MetadataRequest request{{{"t1", "t2", "t3"}}, true}; + request.apiVersion() = 4; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === LEADER-AND-ISR (4) ====================================================== + +TEST_F(RequestDecoderTest, shouldParseLeaderAndIsrRequestV0) { + // given + MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; + MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; + 
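// fields above follow the assumed v0 partition state wire order: topic, partition,
+  // controller_epoch, leader, leader_epoch, isr, zk_version, replicas
+ 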
NULLABLE_ARRAY partition_states{{ps1, ps2}}; + NULLABLE_ARRAY live_leaders{ + {{1, "host1", 9092}, {2, "host2", 9093}, {3, "host3", 9094}}}; + LeaderAndIsrRequest request{20, 1000, partition_states, live_leaders}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === STOP REPLICA (5) ======================================================== + +TEST_F(RequestDecoderTest, shouldParseStopReplicaRequestV0) { + // given + NULLABLE_ARRAY partitions{{{"t1", 0}, {"t2", 3}}}; + StopReplicaRequest request{10, 1000, true, partitions}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === UPDATE METADATA (6) ===================================================== + +TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV0) { + // given + MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; + MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; + NULLABLE_ARRAY partition_states{{ps1, ps2}}; + NULLABLE_ARRAY live_brokers{ + {{1, "host1", 9092}, {2, "host2", 9093}, {3, "host3", 9094}}}; + UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV1) { + // given + MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; + MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; + NULLABLE_ARRAY partition_states{{ps1, ps2}}; + NULLABLE_ARRAY live_brokers{ + {{1, {{{9092, "h1", 0}, {9093, "h1", 1}}}}, // endpoints added in v1, host & port removed + {2, {{{9092, "h2", 2}, {9093, "h2", 3}}}}}}; + UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; + request.apiVersion() = 1; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV2) { + // given + MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; + MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; + NULLABLE_ARRAY partition_states{{ps1, ps2}}; + NULLABLE_ARRAY live_brokers{ + {{1, {{{9092, "h1", 0}, {9093, "h1", 1}}}, "rack1"}, // rack added in v2 + {2, {{{9092, "h2", 2}, {9093, "h2", 3}}}, absl::nullopt}}}; + UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; + request.apiVersion() = 2; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV3) { + // given + MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; + MetadataPartitionState ps2{"t2", 1, 4000, 
2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; + NULLABLE_ARRAY partition_states{{ps1, ps2}}; + NULLABLE_ARRAY live_brokers{ + {{1, + {{{9092, "h1", "plain", 0}, {9093, "h1", "ssl", 1}}}, + "rack1"}, // listener_name added in v2 + {2, {{{9092, "h2", "name3", 2}, {9093, "h2", "name4", 3}}}, absl::nullopt}}}; + UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; + request.apiVersion() = 3; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === CONTROLLED SHUTDOWN (7) ================================================= + +TEST_F(RequestDecoderTest, shouldParseControlledShutdownRequestV1) { + // given + ControlledShutdownRequest request{1}; + request.apiVersion() = 1; + request.correlationId() = 42; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === OFFSET COMMIT (8) ======================================================= + +TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV0) { + // given + NULLABLE_ARRAY topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; + OffsetCommitRequest request{"group_id", topics}; + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) { + // given + // partitions have timestamp in v1 only + NULLABLE_ARRAY topics{ + {{"topic1", {{{0, 10, 100, "m1"}, {2, 20, 101, "m2"}}}}, {"topic2", {{{3, 30, 102, "m3"}}}}}}; + OffsetCommitRequest request{"group_id", + 40, // group_generation_id + "member_id", // member_id + topics}; + request.apiVersion() = 1; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV2toV3) { + // given + NULLABLE_ARRAY topics{ + {{"topic1", {{{0, 10, "m1"}, {2, 20, "m2"}}}}, {"topic2", {{{3, 30, "m3"}}}}}}; + OffsetCommitRequest request{"group_id", 1234, "member", + 2345, // retention_time + topics}; + request.apiVersion() = 2; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === OFFSET FETCH (9) ======================================================== + +TEST_F(RequestDecoderTest, shouldParseOffsetFetchRequestV0toV3) { + // given + OffsetFetchRequest request{"group_id", {{{"topic1", {{0, 1, 2}}}, {"topic2", {{3, 4}}}}}}; + + request.apiVersion() = 0; + request.correlationId() = 10; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + // then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === API VERSIONS (18) ======================================================= + +TEST_F(RequestDecoderTest, shouldParseApiVersionsRequestV0toV1) { + // given + ApiVersionsRequest request{}; + request.apiVersion() = 0; + request.correlationId() = 42; + request.clientId() = "client-id"; + + // when + auto received = serializeAndDeserialize(request); + + 
// then + ASSERT_NE(received, nullptr); + ASSERT_EQ(*received, request); +} + +// === UNKNOWN REQUEST ========================================================= + +TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { + // given + RequestEncoder serializer{buffer_}; + ApiVersionsRequest request{}; + request.apiVersion() = 1; + request.correlationId() = 42; + request.clientId() = "client-id"; + + serializer.encode(request); + + std::shared_ptr mock_listener = std::make_shared(); + RequestParserResolver parser_resolver{{}}; // we do not accept any kind of message here + RequestDecoder testee{parser_resolver, {mock_listener}}; + + MessageSharedPtr rev; + EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&rev)); + + // when + testee.onData(buffer_); + + // then + auto received = std::dynamic_pointer_cast(rev); + ASSERT_NE(received, nullptr); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc new file mode 100644 index 0000000000000..27c0e627b5be4 --- /dev/null +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -0,0 +1,405 @@ +#include "extensions/filters/network/kafka/serialization.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +// === EMPTY (FRESHLY INITIALIZED) BUFFER TESTS ================================ + +// freshly created buffers should not be ready +#define TEST_EmptyBufferShouldNotBeReady(BufferClass) \ + TEST(BufferClass, EmptyBufferShouldNotBeReady) { \ + const BufferClass testee{}; \ + ASSERT_EQ(testee.ready(), false); \ + } + +TEST_EmptyBufferShouldNotBeReady(Int8Buffer); +TEST_EmptyBufferShouldNotBeReady(Int16Buffer); +TEST_EmptyBufferShouldNotBeReady(Int32Buffer); +TEST_EmptyBufferShouldNotBeReady(UInt32Buffer); +TEST_EmptyBufferShouldNotBeReady(Int64Buffer); +TEST_EmptyBufferShouldNotBeReady(BoolBuffer); +TEST_EmptyBufferShouldNotBeReady(StringBuffer); +TEST_EmptyBufferShouldNotBeReady(NullableStringBuffer); +TEST_EmptyBufferShouldNotBeReady(NullableBytesIgnoringBuffer); +TEST(CompositeBuffer, EmptyBufferShouldNotBeReady) { + // given + const CompositeBuffer testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} +TEST(ArrayBuffer, EmptyBufferShouldNotBeReady) { + // given + const ArrayBuffer testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +// Null buffer is a special case, it's always ready and can provide results via 0-arg ctor +TEST(NullBuffer, EmptyBufferShouldBeReady) { + // given + const NullBuffer testee{}; + // when, then + ASSERT_EQ(testee.ready(), true); + ASSERT_EQ(testee.get(), 0); +} + +// === SERIALIZATION / DESERIALIZATION TESTS =================================== + +EncodingContext encoder{-1}; // context is not used when serializing primitive types + +const char* getRawData(const Buffer::OwnedImpl& buffer) { + uint64_t num_slices = buffer.getRawSlices(nullptr, 0); + Buffer::RawSlice slices[num_slices]; + buffer.getRawSlices(slices, num_slices); + return reinterpret_cast((slices[0]).mem_); +} + +// exactly what is says on the tin: +// 1. serialize expected using Encoder +// 2. deserialize byte array using testee buffer +// 3. verify result = expected +// 4. verify that data pointer moved correct amount +// 5. 
feed testee more data +// 6. verify that nothing more was consumed +template +void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { + // given + BT testee{}; + + Buffer::OwnedImpl buffer; + const size_t written = encoder.encode(expected, buffer); + + uint64_t remaining = + 10 * + written; // tell parser that there is more data, it should never consume more than written + const uint64_t orig_remaining = remaining; + const char* data = getRawData(buffer); + const char* orig_data = data; + + // when + const size_t consumed = testee.feed(data, remaining); + + // then + ASSERT_EQ(consumed, written); + ASSERT_EQ(testee.ready(), true); + ASSERT_EQ(testee.get(), expected); + ASSERT_EQ(data, orig_data + consumed); + ASSERT_EQ(remaining, orig_remaining - consumed); + + // when - 2 + const size_t consumed2 = testee.feed(data, remaining); + + // then - 2 (nothing changes) + ASSERT_EQ(consumed2, 0); + ASSERT_EQ(data, orig_data + consumed); + ASSERT_EQ(remaining, orig_remaining - consumed); +} + +// does the same thing as the above test, +// but instead of providing whole data at one, it provides it in N one-byte chunks +// this verifies if buffer keeps state properly +template +void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { + // given + BT testee{}; + + Buffer::OwnedImpl buffer; + const size_t written = encoder.encode(expected, buffer); + + const char* data = getRawData(buffer); + const char* orig_data = data; + + // when + size_t consumed = 0; + for (size_t i = 0; i < written; ++i) { + uint64_t data_size = 1; + consumed += testee.feed(data, data_size); + ASSERT_EQ(data_size, 0); + } + + // then + ASSERT_EQ(consumed, written); + ASSERT_EQ(testee.ready(), true); + ASSERT_EQ(testee.get(), expected); + ASSERT_EQ(data, orig_data + consumed); + + // when - 2 + uint64_t remaining = 1024; + const size_t consumed2 = testee.feed(data, remaining); + + // then - 2 (nothing changes) + ASSERT_EQ(consumed2, 0); + ASSERT_EQ(data, orig_data + consumed); + ASSERT_EQ(remaining, 1024); +} + +template void serializeThenDeserializeAndCheckEquality(AT expected) { + serializeThenDeserializeAndCheckEqualityInOneGo(expected); + serializeThenDeserializeAndCheckEqualityWithChunks(expected); +} + +// === NUMERIC BUFFERS ========================================================= + +// macroed out test for numeric buffers +#define TEST_BufferShouldDeserialize(BufferClass, DataClass, Value) \ + TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \ + /* given */ \ + const DataClass value = Value; \ + serializeThenDeserializeAndCheckEquality(value); \ + } + +TEST_BufferShouldDeserialize(Int8Buffer, INT8, 42); +TEST_BufferShouldDeserialize(Int16Buffer, INT16, 42); +TEST_BufferShouldDeserialize(Int32Buffer, INT32, 42); +TEST_BufferShouldDeserialize(UInt32Buffer, UINT32, 42); +TEST_BufferShouldDeserialize(Int64Buffer, INT64, 42); +TEST_BufferShouldDeserialize(BoolBuffer, BOOLEAN, true); + +// === (NULLABLE) STRING BUFFER ================================================ + +TEST(StringBuffer, ShouldDeserialize) { + const STRING value = "sometext"; + serializeThenDeserializeAndCheckEquality(value); +} + +TEST(StringBuffer, ShouldDeserializeEmptyString) { + const STRING value = ""; + serializeThenDeserializeAndCheckEquality(value); +} + +TEST(StringBuffer, ShouldThrowOnInvalidLength) { + // given + StringBuffer testee; + Buffer::OwnedImpl buffer; + + INT16 len = -1; + encoder.encode(len, buffer); + + uint64_t remaining = 1024; + const char* data = getRawData(buffer); + + // when + // then + 
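// (length -1 denotes null for NULLABLE_STRING, but is invalid for a plain STRING)
+ 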
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
+TEST(NullableStringBuffer, ShouldDeserializeString) {
+  // given
+  const NULLABLE_STRING value{"sometext"};
+  serializeThenDeserializeAndCheckEquality<NullableStringBuffer>(value);
+}
+
+TEST(NullableStringBuffer, ShouldDeserializeEmptyString) {
+  // given
+  const NULLABLE_STRING value{""};
+  serializeThenDeserializeAndCheckEquality<NullableStringBuffer>(value);
+}
+
+TEST(NullableStringBuffer, ShouldDeserializeAbsentString) {
+  // given
+  const NULLABLE_STRING value = absl::nullopt;
+  serializeThenDeserializeAndCheckEquality<NullableStringBuffer>(value);
+}
+
+TEST(NullableStringBuffer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableStringBuffer testee;
+  Buffer::OwnedImpl buffer;
+
+  INT16 len = -2; // -1 is OK for NULLABLE_STRING
+  encoder.encode(len, buffer);
+
+  uint64_t remaining = 1024;
+  const char* data = getRawData(buffer);
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
+// === NULLABLE BYTES IGNORING BUFFER ==========================================
+
+TEST(NullableBytesIgnoringBuffer, ShouldDeserialize) {
+  // given
+  NullableBytesIgnoringBuffer testee;
+
+  const INT32 bytes_to_ignore = 100;
+  NULLABLE_BYTES bytes = {std::vector<unsigned char>(bytes_to_ignore)};
+
+  Buffer::OwnedImpl buffer;
+  size_t written = encoder.encode(bytes, buffer);
+
+  uint64_t remaining = written * 10;
+  const uint64_t orig_remaining = remaining;
+
+  const char* data = getRawData(buffer);
+  const char* orig_data = data;
+
+  // when
+  const size_t consumed = testee.feed(data, remaining);
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), bytes_to_ignore);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, orig_remaining - consumed);
+
+  // when - 2
+  const size_t consumed2 = testee.feed(data, remaining);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, orig_remaining - consumed);
+}
+
+TEST(NullableBytesIgnoringBuffer, ShouldDeserializeNullBytes) {
+  // given
+  NullableBytesIgnoringBuffer testee;
+
+  const INT32 bytes_length = -1;
+
+  Buffer::OwnedImpl buffer;
+  const size_t written = encoder.encode(bytes_length, buffer);
+
+  uint64_t remaining = written * 10;
+  const uint64_t orig_remaining = remaining;
+
+  const char* data = getRawData(buffer);
+  const char* orig_data = data;
+
+  // when
+  const size_t consumed = testee.feed(data, remaining);
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), bytes_length);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, orig_remaining - consumed);
+
+  // when - 2
+  const size_t consumed2 = testee.feed(data, remaining);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, orig_remaining - consumed);
+}
+
+TEST(NullableBytesIgnoringBuffer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableBytesIgnoringBuffer testee;
+  Buffer::OwnedImpl buffer;
+
+  const INT32 bytes_length = -2; // -1 is OK for NULLABLE_BYTES
+  encoder.encode(bytes_length, buffer);
+
+  uint64_t remaining = 1024;
+  const char* data = getRawData(buffer);
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
+// === NULLABLE BYTES CAPTURING BUFFER =========================================
+
+TEST(NullableBytesCapturingBuffer, ShouldDeserialize) {
+  const NULLABLE_BYTES value{{'a', 'b', 'c', 'd'}};
+  serializeThenDeserializeAndCheckEquality<NullableBytesCapturingBuffer>(value);
+}
+
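+// Illustrative sketch (an assumption-laden example, not taken from the change itself): the
+// NULLABLE_BYTES wire format decoded above is an INT32 length prefix (-1 denotes null) followed
+// by that many raw bytes, here spelled out by hand instead of produced via EncodingContext.
+// Assumes big-endian (network order) integers as used by the Kafka protocol, and only the
+// feed()/ready()/get() Deserializer API exercised by the surrounding tests.
+TEST(NullableBytesCapturingBuffer, ShouldDeserializeHandRolledBytes) {
+  // given: length prefix 4 (big-endian), then payload 'a' 'b' 'c' 'd'
+  const char raw[] = {0x00, 0x00, 0x00, 0x04, 'a', 'b', 'c', 'd'};
+  const char* data = raw;
+  uint64_t remaining = sizeof(raw);
+  NullableBytesCapturingBuffer testee;
+
+  // when
+  const size_t consumed = testee.feed(data, remaining);
+
+  // then: everything was consumed in a single pass and the payload was captured
+  ASSERT_EQ(consumed, sizeof(raw));
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(remaining, 0);
+  const NULLABLE_BYTES expected{{'a', 'b', 'c', 'd'}};
+  ASSERT_EQ(testee.get(), expected);
+}
+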
+TEST(NullableBytesCapturingBuffer, ShouldDeserializeEmptyBytes) {
+  const NULLABLE_BYTES value{{}};
+  serializeThenDeserializeAndCheckEquality<NullableBytesCapturingBuffer>(value);
+}
+
+TEST(NullableBytesCapturingBuffer, ShouldDeserializeNullBytes) {
+  const NULLABLE_BYTES value = absl::nullopt;
+  serializeThenDeserializeAndCheckEquality<NullableBytesCapturingBuffer>(value);
+}
+
+TEST(NullableBytesCapturingBuffer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableBytesCapturingBuffer testee;
+  Buffer::OwnedImpl buffer;
+
+  const INT32 bytes_length = -2; // -1 is OK for NULLABLE_BYTES
+  encoder.encode(bytes_length, buffer);
+
+  uint64_t remaining = 1024;
+  const char* data = getRawData(buffer);
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
+// === ARRAY BUFFER ============================================================
+
+TEST(ArrayBuffer, ShouldConsumeCorrectAmountOfData) {
+  const NULLABLE_ARRAY<STRING> value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}};
+  serializeThenDeserializeAndCheckEquality<ArrayBuffer<StringBuffer>>(value);
+}
+
+TEST(ArrayBuffer, ShouldThrowOnInvalidLength) {
+  // given
+  ArrayBuffer testee;
+  Buffer::OwnedImpl buffer;
+
+  const INT32 len = -2; // -1 is OK for ARRAY
+  encoder.encode(len, buffer);
+
+  uint64_t remaining = 1024;
+  const char* data = getRawData(buffer);
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
+// === COMPOSITE BUFFER ========================================================
+
+struct CompositeBufferResult {
+  STRING field1_;
+  NULLABLE_ARRAY<INT32> field2_;
+  INT16 field3_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    written += encoder.encode(field2_, dst);
+    written += encoder.encode(field3_, dst);
+    return written;
+  }
+};
+
+bool operator==(const CompositeBufferResult& lhs, const CompositeBufferResult& rhs) {
+  return (lhs.field1_ == rhs.field1_) && (lhs.field2_ == rhs.field2_) &&
+         (lhs.field3_ == rhs.field3_);
+}
+
+typedef CompositeBuffer<CompositeBufferResult, StringBuffer, ArrayBuffer<Int32Buffer>,
+                        Int16Buffer>
+    TestCompositeBuffer;
+
+TEST(CompositeBuffer, ShouldDeserialize) {
+  const CompositeBufferResult expected{"zzzzz", {{10, 20, 30, 40, 50}}, 1234};
+  serializeThenDeserializeAndCheckEquality<TestCompositeBuffer>(expected);
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy

From aa51c6b5d436180742822ee2c1335355576bc182 Mon Sep 17 00:00:00 2001
From: "adam.kotwasinski"
Date: Sat, 3 Nov 2018 03:19:43 +0000
Subject: [PATCH 02/29] Fix compile error after rebases

Signed-off-by: Adam Kotwasinski
---
 test/extensions/filters/network/kafka/kafka_request_test.cc | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc
index c57361c5edbce..ce481ff71c38f 100644
--- a/test/extensions/filters/network/kafka/kafka_request_test.cc
+++ b/test/extensions/filters/network/kafka/kafka_request_test.cc
@@ -70,7 +70,7 @@ class BufferBasedTest : public testing::Test {
     return reinterpret_cast<const char*>((slices[0]).mem_);
   }
 
-private:
+protected:
   Buffer::OwnedImpl buffer_;
   EncodingContext encoder_{-1};
 };

From 074f8687b3933d256af1069e88eae5c24a1e032d Mon Sep 17 00:00:00 2001
From: "adam.kotwasinski"
Date: Wed, 14 Nov 2018 00:00:00 +0000
Subject: [PATCH 03/29] Remove all request types except OffsetCommit v0..v1
 (for review)

Signed-off-by: Adam Kotwasinski
---
 .../filters/network/kafka/kafka_request.cc    |  29 -
 .../filters/network/kafka/kafka_request.h     | 918 ------------------
.../filters/network/kafka/serialization.h | 89 -- .../network/kafka/kafka_request_test.cc | 8 +- .../network/kafka/request_codec_test.cc | 430 +------- 5 files changed, 6 insertions(+), 1468 deletions(-) diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request.cc index e4ec7b0a8d25f..d308cf80514fb 100644 --- a/source/extensions/filters/network/kafka/kafka_request.cc +++ b/source/extensions/filters/network/kafka/kafka_request.cc @@ -34,38 +34,9 @@ GeneratorMap computeGeneratorMap(std::vector specs) { } const RequestParserResolver RequestParserResolver::KAFKA_0_11{{ - ParserSpec{RequestType::Produce, - {0, 1, 2}, - [](RequestContextSharedPtr arg) -> ParserSharedPtr { - return std::make_shared(arg); - }}, - ParserSpec{RequestType::Produce, - {3}, - [](RequestContextSharedPtr arg) -> ParserSharedPtr { - return std::make_shared(arg); - }}, - PARSER_SPEC(Fetch, V0, 0, 1, 2), - PARSER_SPEC(Fetch, V3, 3), - PARSER_SPEC(Fetch, V4, 4), - PARSER_SPEC(Fetch, V5, 5), - PARSER_SPEC(ListOffsets, V0, 0), - PARSER_SPEC(ListOffsets, V1, 1), - PARSER_SPEC(ListOffsets, V2, 2), - PARSER_SPEC(Metadata, V0, 0, 1, 2, 3), - PARSER_SPEC(Metadata, V4, 4), - PARSER_SPEC(LeaderAndIsr, V0, 0), - PARSER_SPEC(StopReplica, V0, 0), - PARSER_SPEC(UpdateMetadata, V0, 0), - PARSER_SPEC(UpdateMetadata, V1, 1), - PARSER_SPEC(UpdateMetadata, V2, 2), - PARSER_SPEC(UpdateMetadata, V3, 3), - PARSER_SPEC(ControlledShutdown, V1, 1), PARSER_SPEC(OffsetCommit, V0, 0), PARSER_SPEC(OffsetCommit, V1, 1), - PARSER_SPEC(OffsetCommit, V2, 2, 3), - PARSER_SPEC(OffsetFetch, V0, 0, 1, 2, 3), // XXX(adam.kotwasinski) missing request types here - PARSER_SPEC(ApiVersions, V0, 0, 1), }}; ParserSharedPtr RequestParserResolver::createParser(INT16 api_key, INT16 api_version, diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index c7b8754c36342..6727dd2a84204 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -232,837 +232,6 @@ class Request : public Message { RequestHeader request_header_; }; -// === PRODUCE (0) ============================================================= - -/** - * Produce request parser is a special case that has two corresponding parsers - * One parser captures data, the other one does not, only saving the length of data provided - * This might be used in filters that do not need access to data (e.g. 
only want to update request - * type metrics) - */ - -// holds data sent by client -struct FatProducePartition { - const INT32 partition_; - const NULLABLE_BYTES data_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(data_, dst); - return written; - } - - bool operator==(const FatProducePartition& rhs) const { - return partition_ == rhs.partition_ && data_ == rhs.data_; - }; - - friend std::ostream& operator<<(std::ostream& os, const FatProducePartition& arg) { - os << "{partition=" << arg.partition_ << ", data(size)="; - if (arg.data_.has_value()) { - os << arg.data_->size(); - } else { - os << ""; - } - return os << "}"; - } -}; - -// does not carry data, only its length -struct ThinProducePartition { - const INT32 partition_; - const INT32 data_size_; - - size_t encode(Buffer::Instance&, EncodingContext&) const { - throw EnvoyException("ThinProducePartition cannot be encoded"); - } - - bool operator==(const ThinProducePartition& rhs) const { - return partition_ == rhs.partition_ && data_size_ == rhs.data_size_; - }; - - friend std::ostream& operator<<(std::ostream& os, const ThinProducePartition& arg) { - return os << "{partition=" << arg.partition_ << ", data_size=" << arg.data_size_ << "}"; - } -}; - -template struct ProduceTopic { - const STRING topic_; - const NULLABLE_ARRAY partitions_; - - bool operator==(const ProduceTopic& rhs) const { - return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; - }; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - friend std::ostream& operator<<(std::ostream& os, const ProduceTopic& arg) { - return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}"; - } -}; - -typedef ProduceTopic FatProduceTopic; -typedef ProduceTopic ThinProduceTopic; - -/** - * PT carries partition type, which can be capturing (contains bytes) or non-capturing (contains - * bytes' length only) - */ -template class ProduceRequest : public Request { -public: - // v0 .. 
v2 - ProduceRequest(INT16 acks, INT32 timeout, NULLABLE_ARRAY> topics) - : ProduceRequest(absl::nullopt, acks, timeout, topics){}; - - // v3 - ProduceRequest(NULLABLE_STRING transactional_id, INT16 acks, INT32 timeout, - NULLABLE_ARRAY> topics) - : Request{RequestType::Produce}, - transactional_id_{transactional_id}, acks_{acks}, timeout_{timeout}, topics_{topics} {}; - - bool operator==(const ProduceRequest& rhs) const { - return request_header_ == rhs.request_header_ && transactional_id_ == rhs.transactional_id_ && - acks_ == rhs.acks_ && timeout_ == rhs.timeout_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - if (request_header_.api_version_ >= 3) { - written += encoder.encode(transactional_id_, dst); - } - written += encoder.encode(acks_, dst); - written += encoder.encode(timeout_, dst); - written += encoder.encode(topics_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{transactional_id=" << transactional_id_ << ", acks=" << acks_ - << ", timeout=" << timeout_ << ", topics=" << topics_ << "}"; - }; - -private: - const NULLABLE_STRING transactional_id_; - const INT16 acks_; - const INT32 timeout_; - const NULLABLE_ARRAY> topics_; -}; - -typedef ProduceRequest FatProduceRequest; -typedef ProduceRequest ThinProduceRequest; - -// clang-format off -class ThinProducePartitionArrayBuffer : public ArrayBuffer> {}; -class ProducePartitionArrayBuffer : public ArrayBuffer> {}; - -class ThinProduceTopicArrayBuffer : public ArrayBuffer> {}; -class ProduceTopicArrayBuffer : public ArrayBuffer> {}; - -class ThinProduceRequestV0Buffer : public CompositeBuffer {}; -class ThinProduceRequestV3Buffer : public CompositeBuffer {}; -class FatProduceRequestV0Buffer : public CompositeBuffer {}; -class FatProduceRequestV3Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(ThinProduceRequest, V0); -DEFINE_REQUEST_PARSER(ThinProduceRequest, V3); -DEFINE_REQUEST_PARSER(FatProduceRequest, V0); -DEFINE_REQUEST_PARSER(FatProduceRequest, V3); -// clang-format on - -// === FETCH (1) =============================================================== - -struct FetchRequestPartition { - const INT32 partition_; - const INT64 fetch_offset_; - const INT64 log_start_offset_; // since v5 - const INT32 max_bytes_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(fetch_offset_, dst); - if (encoder.apiVersion() >= 5) { - written += encoder.encode(log_start_offset_, dst); - } - written += encoder.encode(max_bytes_, dst); - return written; - } - - friend std::ostream& operator<<(std::ostream& os, const FetchRequestPartition& arg) { - return os << "{partition=" << arg.partition_ << ", fetch_offset=" << arg.fetch_offset_ - << ", log_start_offset=" << arg.log_start_offset_ << ", max_bytes=" << arg.max_bytes_ - << "}"; - } - - bool operator==(const FetchRequestPartition& rhs) const { - return partition_ == rhs.partition_ && fetch_offset_ == rhs.fetch_offset_ && - log_start_offset_ == rhs.log_start_offset_ && max_bytes_ == rhs.max_bytes_; - }; - - // v0 .. 
v4 - FetchRequestPartition(INT32 partition, INT64 fetch_offset, INT32 max_bytes) - : FetchRequestPartition(partition, fetch_offset, -1, max_bytes){}; - - // v5 - FetchRequestPartition(INT32 partition, INT64 fetch_offset, INT64 log_start_offset, - INT32 max_bytes) - : partition_{partition}, fetch_offset_{fetch_offset}, log_start_offset_{log_start_offset}, - max_bytes_{max_bytes} {}; -}; - -struct FetchRequestTopic { - const STRING topic_; - const NULLABLE_ARRAY partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const FetchRequestTopic& rhs) const { - return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; - }; - - friend std::ostream& operator<<(std::ostream& os, const FetchRequestTopic& arg) { - return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}"; - } -}; - -class FetchRequest : public Request { -public: - // v0 .. v2 - FetchRequest(INT32 replica_id, INT32 max_wait_time, INT32 min_bytes, - NULLABLE_ARRAY topics) - : FetchRequest(replica_id, max_wait_time, min_bytes, -1, topics){}; - - // v3 - FetchRequest(INT32 replica_id, INT32 max_wait_time, INT32 min_bytes, INT32 max_bytes, - NULLABLE_ARRAY topics) - : FetchRequest(replica_id, max_wait_time, min_bytes, max_bytes, -1, topics){}; - - // v4 .. v5 - FetchRequest(INT32 replica_id, INT32 max_wait_time, INT32 min_bytes, INT32 max_bytes, - INT8 isolation_level, NULLABLE_ARRAY topics) - : Request{RequestType::Fetch}, replica_id_{replica_id}, max_wait_time_{max_wait_time}, - min_bytes_{min_bytes}, max_bytes_{max_bytes}, - isolation_level_{isolation_level}, topics_{topics} {}; - - bool operator==(const FetchRequest& rhs) const { - return request_header_ == rhs.request_header_ && replica_id_ == rhs.replica_id_ && - max_wait_time_ == rhs.max_wait_time_ && min_bytes_ == rhs.min_bytes_ && - max_bytes_ == rhs.max_bytes_ && isolation_level_ == rhs.isolation_level_ && - topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - INT16 api_version = request_header_.api_version_; - written += encoder.encode(replica_id_, dst); - written += encoder.encode(max_wait_time_, dst); - written += encoder.encode(min_bytes_, dst); - if (api_version >= 3) { - written += encoder.encode(max_bytes_, dst); - } - if (api_version >= 4) { - written += encoder.encode(isolation_level_, dst); - } - written += encoder.encode(topics_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{replica_id=" << replica_id_ << ", max_wait_time=" << max_wait_time_ - << ", min_bytes=" << min_bytes_ << ", max_bytes=" << max_bytes_ - << ", isolation_level=" << static_cast(isolation_level_) - << ", topics=" << topics_ << "}"; - } - -private: - const INT32 replica_id_; - const INT32 max_wait_time_; - const INT32 min_bytes_; - const INT32 max_bytes_; // since v3 - const INT8 isolation_level_; // since v4 - const NULLABLE_ARRAY topics_; -}; - -// clang-format off -class FetchRequestPartitionV0Buffer : public CompositeBuffer {}; -class FetchRequestPartitionV0ArrayBuffer : public ArrayBuffer {}; -class FetchRequestTopicV0Buffer : public CompositeBuffer {}; -class FetchRequestTopicV0ArrayBuffer : public ArrayBuffer {}; - -class FetchRequestPartitionV5Buffer : public CompositeBuffer {}; -class FetchRequestPartitionV5ArrayBuffer : 
public ArrayBuffer {}; -class FetchRequestTopicV5Buffer : public CompositeBuffer {}; -class FetchRequestTopicV5ArrayBuffer : public ArrayBuffer {}; - -class FetchRequestV0Buffer : public CompositeBuffer {}; -class FetchRequestV3Buffer : public CompositeBuffer {}; -class FetchRequestV4Buffer : public CompositeBuffer {}; -class FetchRequestV5Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(FetchRequest, V0); -DEFINE_REQUEST_PARSER(FetchRequest, V3); -DEFINE_REQUEST_PARSER(FetchRequest, V4); -DEFINE_REQUEST_PARSER(FetchRequest, V5); -// clang-format on - -// === LIST OFFSETS (2) ======================================================== - -struct ListOffsetsPartition { - const INT32 partition_; - const INT64 timestamp_; - const INT32 max_num_offsets_; // only v0 - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(timestamp_, dst); - if (encoder.apiVersion() == 0) { - written += encoder.encode(max_num_offsets_, dst); - } - return written; - } - - bool operator==(const ListOffsetsPartition& rhs) const { - return partition_ == rhs.partition_ && timestamp_ == rhs.timestamp_ && - max_num_offsets_ == rhs.max_num_offsets_; - }; - - friend std::ostream& operator<<(std::ostream& os, const ListOffsetsPartition& arg) { - return os << "{partition=" << arg.partition_ << ", timestamp=" << arg.timestamp_ - << ", max_num_offsets=" << arg.max_num_offsets_ << "}"; - } - - // v0 - ListOffsetsPartition(INT32 partition, INT64 timestamp, INT32 max_num_offsets) - : partition_{partition}, timestamp_{timestamp}, max_num_offsets_{max_num_offsets} {}; - - // v1 .. v2 - ListOffsetsPartition(INT32 partition, INT64 timestamp) - : ListOffsetsPartition(partition, timestamp, -1){}; -}; - -struct ListOffsetsTopic { - const STRING topic_; - const NULLABLE_ARRAY partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const ListOffsetsTopic& rhs) const { - return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; - }; - - friend std::ostream& operator<<(std::ostream& os, const ListOffsetsTopic& arg) { - return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}"; - } -}; - -class ListOffsetsRequest : public Request { -public: - // v0 .. 
v1 - ListOffsetsRequest(INT32 replica_id, NULLABLE_ARRAY topics) - : ListOffsetsRequest(replica_id, -1, topics){}; - - // v2 - ListOffsetsRequest(INT32 replica_id, INT8 isolation_level, - NULLABLE_ARRAY topics) - : Request{RequestType::ListOffsets}, replica_id_{replica_id}, - isolation_level_{isolation_level}, topics_{topics} {}; - - bool operator==(const ListOffsetsRequest& rhs) const { - return request_header_ == rhs.request_header_ && replica_id_ == rhs.replica_id_ && - isolation_level_ == rhs.isolation_level_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(replica_id_, dst); - if (encoder.apiVersion() >= 2) { - written += encoder.encode(isolation_level_, dst); - } - written += encoder.encode(topics_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{replica_id=" << replica_id_ - << ", isolation_level=" << static_cast(isolation_level_) - << ", topics=" << topics_ << "}"; - } - -private: - const INT32 replica_id_; - const INT8 isolation_level_; // since v2 - const NULLABLE_ARRAY topics_; -}; - -// clang-format off -class ListOffsetsPartitionV0Buffer : public CompositeBuffer {}; -class ListOffsetsPartitionV0ArrayBuffer : public ArrayBuffer {}; -class ListOffsetsTopicV0Buffer : public CompositeBuffer {}; -class ListOffsetsTopicV0ArrayBuffer : public ArrayBuffer {}; - -class ListOffsetsPartitionV1Buffer : public CompositeBuffer {}; -class ListOffsetsPartitionV1ArrayBuffer : public ArrayBuffer {}; -class ListOffsetsTopicV1Buffer : public CompositeBuffer {}; -class ListOffsetsTopicV1ArrayBuffer : public ArrayBuffer {}; - -class ListOffsetsRequestV0Buffer : public CompositeBuffer {}; -class ListOffsetsRequestV1Buffer : public CompositeBuffer {}; -class ListOffsetsRequestV2Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(ListOffsetsRequest, V0); -DEFINE_REQUEST_PARSER(ListOffsetsRequest, V1); -DEFINE_REQUEST_PARSER(ListOffsetsRequest, V2); -// clang-format on - -// === METADATA (3) ============================================================ - -class MetadataRequest : public Request { -public: - // v0 .. 
v3 - MetadataRequest(NULLABLE_ARRAY topics) : MetadataRequest(topics, false){}; - - // v4 - MetadataRequest(NULLABLE_ARRAY topics, BOOLEAN allow_auto_topic_creation) - : Request{RequestType::Metadata}, topics_{topics}, allow_auto_topic_creation_{ - allow_auto_topic_creation} {}; - - bool operator==(const MetadataRequest& rhs) const { - return request_header_ == rhs.request_header_ && topics_ == rhs.topics_ && - allow_auto_topic_creation_ == rhs.allow_auto_topic_creation_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(topics_, dst); - if (encoder.apiVersion() >= 2) { - written += encoder.encode(allow_auto_topic_creation_, dst); - } - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{topics=" << topics_ - << ", allow_auto_topic_creation=" << allow_auto_topic_creation_ << "}"; - } - -private: - NULLABLE_ARRAY topics_; - BOOLEAN allow_auto_topic_creation_; // since v4 -}; - -// clang-format off -class MetadataRequestTopicV0Buffer : public ArrayBuffer {}; -class MetadataRequestV0Buffer : public CompositeBuffer {}; -class MetadataRequestV4Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(MetadataRequest, V0); -DEFINE_REQUEST_PARSER(MetadataRequest, V4); -// clang-format on - -// === LEADER-AND-ISR (4) ====================================================== - -/** - * This structure is used in both LeaderAndIsr v0 & UpdateMetadata - */ - -struct MetadataPartitionState { - const STRING topic_; - const INT32 partition_; - const INT32 controller_epoch_; - const INT32 leader_; - const INT32 leader_epoch_; - const NULLABLE_ARRAY isr_; - const INT32 zk_version_; - const NULLABLE_ARRAY replicas_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partition_, dst); - written += encoder.encode(controller_epoch_, dst); - written += encoder.encode(leader_, dst); - written += encoder.encode(leader_epoch_, dst); - written += encoder.encode(isr_, dst); - written += encoder.encode(zk_version_, dst); - written += encoder.encode(replicas_, dst); - return written; - } - - bool operator==(const MetadataPartitionState& rhs) const { - return topic_ == rhs.topic_ && partition_ == rhs.partition_ && - controller_epoch_ == rhs.controller_epoch_ && leader_ == rhs.leader_ && - leader_epoch_ == rhs.leader_epoch_ && isr_ == rhs.isr_ && - zk_version_ == rhs.zk_version_ && replicas_ == rhs.replicas_; - }; - - friend std::ostream& operator<<(std::ostream& os, const MetadataPartitionState& arg) { - return os << "{topic=" << arg.topic_ << ", partition=" << arg.partition_ - << ", controller_epoch=" << arg.controller_epoch_ << ", leader=" << arg.leader_ - << ", leader_epoch=" << arg.leader_epoch_ << ", isr=" << arg.isr_ - << ", zk_version=" << arg.zk_version_ << ", zk_version=" << arg.zk_version_ << "}"; - } -}; - -struct LeaderAndIsrLiveLeader { - const INT32 id_; - const STRING host_; - const INT32 port_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(id_, dst); - written += encoder.encode(host_, dst); - written += encoder.encode(port_, dst); - return written; - } - - bool operator==(const LeaderAndIsrLiveLeader& rhs) const { - return id_ == rhs.id_ && host_ == rhs.host_ && port_ == rhs.port_; - }; - - friend std::ostream& operator<<(std::ostream& os, const 
LeaderAndIsrLiveLeader& arg) { - return os << "{id=" << arg.id_ << ", host=" << arg.host_ << ", port=" << arg.port_ << "}"; - } -}; - -class LeaderAndIsrRequest : public Request { -public: - // v0 - LeaderAndIsrRequest(INT32 controller_id, INT32 controller_epoch, - NULLABLE_ARRAY partition_states, - NULLABLE_ARRAY live_readers) - : Request{RequestType::LeaderAndIsr}, controller_id_{controller_id}, - controller_epoch_{controller_epoch}, partition_states_{partition_states}, - live_readers_{live_readers} {}; - - bool operator==(const LeaderAndIsrRequest& rhs) const { - return request_header_ == rhs.request_header_ && controller_id_ == rhs.controller_id_ && - controller_epoch_ == rhs.controller_epoch_ && - partition_states_ == rhs.partition_states_ && live_readers_ == rhs.live_readers_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(controller_id_, dst); - written += encoder.encode(controller_epoch_, dst); - written += encoder.encode(partition_states_, dst); - written += encoder.encode(live_readers_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{controller_id=" << controller_id_ << ", controller_epoch=" << controller_epoch_ - << ", partition_states=" << partition_states_ << ", live_readers=" << live_readers_ - << "}"; - } - -private: - const INT32 controller_id_; - const INT32 controller_epoch_; - const NULLABLE_ARRAY partition_states_; - const NULLABLE_ARRAY live_readers_; -}; - -// clang-format off -class MetadataPartitionStateV0Buffer : public CompositeBuffer, - Int32Buffer, - ArrayBuffer - > {}; -class MetadataPartitionStateV0ArrayBuffer : public ArrayBuffer {}; - -class LeaderAndIsrLiveLeaderV0Buffer : public CompositeBuffer {}; -class LeaderAndIsrLiveLeaderV0ArrayBuffer : public ArrayBuffer {}; - -class LeaderAndIsrRequestV0Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(LeaderAndIsrRequest, V0); -// clang-format on - -// === STOP REPLICA (5) ======================================================== - -struct StopReplicaPartition { - const STRING topic_; - const INT32 partition_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partition_, dst); - return written; - } - - bool operator==(const StopReplicaPartition& rhs) const { - return topic_ == rhs.topic_ && partition_ == rhs.partition_; - }; - - friend std::ostream& operator<<(std::ostream& os, const StopReplicaPartition& arg) { - return os << "{topic=" << arg.topic_ << ", partition=" << arg.partition_ << "}"; - } -}; - -class StopReplicaRequest : public Request { -public: - // v0 - StopReplicaRequest(INT32 controller_id, INT32 controller_epoch, BOOLEAN delete_partitions, - NULLABLE_ARRAY partitions) - : Request{RequestType::StopReplica}, controller_id_{controller_id}, - controller_epoch_{controller_epoch}, delete_partitions_{delete_partitions}, - partitions_{partitions} {}; - - bool operator==(const StopReplicaRequest& rhs) const { - return request_header_ == rhs.request_header_ && controller_id_ == rhs.controller_id_ && - controller_epoch_ == rhs.controller_epoch_ && - delete_partitions_ == rhs.delete_partitions_ && partitions_ == rhs.partitions_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(controller_id_, dst); - written += 
encoder.encode(controller_epoch_, dst); - written += encoder.encode(delete_partitions_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{controller_id=" << controller_id_ << ", controller_epoch=" << controller_epoch_ - << ", delete_partitions=" << delete_partitions_ << ", partitions=" << partitions_ - << "}"; - } - -private: - const INT32 controller_id_; - const INT32 controller_epoch_; - const BOOLEAN delete_partitions_; - const NULLABLE_ARRAY partitions_; -}; - -// clang-format off -class StopReplicaPartitionV0Buffer : public CompositeBuffer {}; -class StopReplicaPartitionV0ArrayBuffer : public ArrayBuffer {}; - -class StopReplicaRequestV0Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(StopReplicaRequest, V0); -// clang-format on - -// === UPDATE METADATA (6) ===================================================== - -// uses MetadataPartitionState from LeaderAndIsr - -struct UpdateMetadataLiveBrokerEndpoint { - const INT32 port_; - const STRING host_; - const STRING listener_name_; - const INT16 security_protocol_type_; - - // v1 .. v2 - UpdateMetadataLiveBrokerEndpoint(INT32 port, STRING host, INT16 security_protocol_type) - : UpdateMetadataLiveBrokerEndpoint{port, host, "", security_protocol_type} {}; - - // v3 - UpdateMetadataLiveBrokerEndpoint(INT32 port, STRING host, STRING listener_name, - INT16 security_protocol_type) - : port_{port}, host_{host}, listener_name_{listener_name}, security_protocol_type_{ - security_protocol_type} {}; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(port_, dst); - written += encoder.encode(host_, dst); - if (encoder.apiVersion() >= 3) { - written += encoder.encode(listener_name_, dst); - } - written += encoder.encode(security_protocol_type_, dst); - return written; - } - - bool operator==(const UpdateMetadataLiveBrokerEndpoint& rhs) const { - return port_ == rhs.port_ && host_ == rhs.host_ && listener_name_ == rhs.listener_name_ && - security_protocol_type_ == rhs.security_protocol_type_; - }; - - friend std::ostream& operator<<(std::ostream& os, const UpdateMetadataLiveBrokerEndpoint& arg) { - return os << "{port=" << arg.port_ << ", host=" << arg.host_ - << ", listener_name=" << arg.listener_name_ - << ", security_protocol_type=" << arg.security_protocol_type_ << "}"; - } -}; - -struct UpdateMetadataLiveBroker { - const INT32 id_; - const NULLABLE_ARRAY endpoints_; - const NULLABLE_STRING rack_; // since v2 - - // v0 - // instead of having dedicated fields, store data as single UpdateMetadataLiveBrokerEndpoint (java - // client does it as well) - UpdateMetadataLiveBroker(INT32 id, STRING host, INT32 port) - : UpdateMetadataLiveBroker(id, {{UpdateMetadataLiveBrokerEndpoint{port, host, 0}}}){}; - - // v1 - UpdateMetadataLiveBroker(INT32 id, NULLABLE_ARRAY endpoints) - : UpdateMetadataLiveBroker{id, endpoints, absl::nullopt} {}; - - // v2 - UpdateMetadataLiveBroker(INT32 id, NULLABLE_ARRAY endpoints, - NULLABLE_STRING rack) - : id_{id}, endpoints_{endpoints}, rack_{rack} {}; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(id_, dst); - if (encoder.apiVersion() == 0) { - // we stored host+port as endpoint, but need to serialize properly - const UpdateMetadataLiveBrokerEndpoint& only_endpoint = (*endpoints_)[0]; - written += encoder.encode(only_endpoint.host_, dst); - written += 
encoder.encode(only_endpoint.port_, dst); - } else { - written += encoder.encode(endpoints_, dst); - if (encoder.apiVersion() >= 2) { - written += encoder.encode(rack_, dst); - } - } - return written; - } - - bool operator==(const UpdateMetadataLiveBroker& rhs) const { - return id_ == rhs.id_ && endpoints_ == rhs.endpoints_ && rack_ == rhs.rack_; - }; - - friend std::ostream& operator<<(std::ostream& os, const UpdateMetadataLiveBroker& arg) { - return os << "{id=" << arg.id_ << ", endpoints=" << arg.endpoints_ << ", rack=" << arg.rack_ - << "}"; - } -}; - -class UpdateMetadataRequest : public Request { -public: - // v0 - UpdateMetadataRequest(INT32 controller_id, INT32 controller_epoch, - NULLABLE_ARRAY partition_states, - NULLABLE_ARRAY live_brokers) - : Request{RequestType::UpdateMetadata}, controller_id_{controller_id}, - controller_epoch_{controller_epoch}, partition_states_{partition_states}, - live_brokers_{live_brokers} {}; - - bool operator==(const UpdateMetadataRequest& rhs) const { - return request_header_ == rhs.request_header_ && controller_id_ == rhs.controller_id_ && - controller_epoch_ == rhs.controller_epoch_ && - partition_states_ == rhs.partition_states_ && live_brokers_ == rhs.live_brokers_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(controller_id_, dst); - written += encoder.encode(controller_epoch_, dst); - written += encoder.encode(partition_states_, dst); - written += encoder.encode(live_brokers_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{controller_id=" << controller_id_ << ", controller_epoch=" << controller_epoch_ - << ", partition_states=" << partition_states_ << ", live_brokers=" << live_brokers_ - << "}"; - } - -private: - const INT32 controller_id_; - const INT32 controller_epoch_; - const NULLABLE_ARRAY partition_states_; - const NULLABLE_ARRAY live_brokers_; -}; - -// clang-format off -class UpdateMetadataLiveBrokerEndpointV1Buffer : public CompositeBuffer {}; -class UpdateMetadataLiveBrokerEndpointV1ArrayBuffer : public ArrayBuffer {}; -class UpdateMetadataLiveBrokerEndpointV3Buffer : public CompositeBuffer {}; -class UpdateMetadataLiveBrokerEndpointV3ArrayBuffer : public ArrayBuffer {}; - -class UpdateMetadataLiveBrokerV0Buffer : public CompositeBuffer {}; -class UpdateMetadataLiveBrokerV0ArrayBuffer : public ArrayBuffer {}; -class UpdateMetadataLiveBrokerV1Buffer : public CompositeBuffer {}; -class UpdateMetadataLiveBrokerV1ArrayBuffer : public ArrayBuffer {}; -class UpdateMetadataLiveBrokerV2Buffer : public CompositeBuffer {}; -class UpdateMetadataLiveBrokerV2ArrayBuffer : public ArrayBuffer {}; -class UpdateMetadataLiveBrokerV3Buffer : public CompositeBuffer {}; -class UpdateMetadataLiveBrokerV3ArrayBuffer : public ArrayBuffer {}; - -class UpdateMetadataRequestV0Buffer : public CompositeBuffer {}; -class UpdateMetadataRequestV1Buffer : public CompositeBuffer {}; -class UpdateMetadataRequestV2Buffer : public CompositeBuffer {}; -class UpdateMetadataRequestV3Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V0); -DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V1); -DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V2); -DEFINE_REQUEST_PARSER(UpdateMetadataRequest, V3); -// clang-format on - -// === CONTROLLED SHUTDOWN (7) ================================================= - -// v0 is not documented -class ControlledShutdownRequest : public Request { 
-public: - // v1 - ControlledShutdownRequest(INT32 broker_id) - : Request{RequestType::ControlledShutdown}, broker_id_{broker_id} {}; - - bool operator==(const ControlledShutdownRequest& rhs) const { - return request_header_ == rhs.request_header_ && broker_id_ == rhs.broker_id_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(broker_id_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{broker_id=" << broker_id_ << "}"; - } - -private: - const INT32 broker_id_; -}; - -// clang-format off -class ControlledShutdownRequestV1Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(ControlledShutdownRequest, V1); -// clang-format on - // === OFFSET COMMIT (8) ======================================================= struct OffsetCommitPartition { @@ -1185,98 +354,11 @@ class OffsetCommitPartitionV1ArrayBuffer : public ArrayBuffer {}; class OffsetCommitTopicV1ArrayBuffer : public ArrayBuffer {}; -class OffsetCommitTopicV2ArrayBuffer : public OffsetCommitTopicV0ArrayBuffer {}; // v2 partition format is the same as v0 - class OffsetCommitRequestV0Buffer : public CompositeBuffer {}; class OffsetCommitRequestV1Buffer : public CompositeBuffer {}; -class OffsetCommitRequestV2Buffer : public CompositeBuffer {}; DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0); DEFINE_REQUEST_PARSER(OffsetCommitRequest, V1); -DEFINE_REQUEST_PARSER(OffsetCommitRequest, V2); -// clang-format on - -// === OFFSET FETCH (9) ======================================================== - -struct OffsetFetchTopic { - const STRING topic_; - const NULLABLE_ARRAY partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetFetchTopic& rhs) const { - return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; - }; - - friend std::ostream& operator<<(std::ostream& os, const OffsetFetchTopic& arg) { - return os << "{topic=" << arg.topic_ << ", partitions=" << arg.partitions_ << "}"; - } -}; - -class OffsetFetchRequest : public Request { -public: - // v0 .. 
v3 - OffsetFetchRequest(STRING group_id, NULLABLE_ARRAY topics) - : Request{RequestType::OffsetFetch}, group_id_{group_id}, topics_{topics} {}; - - bool operator==(const OffsetFetchRequest& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && - topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - written += encoder.encode(topics_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{group_id=" << group_id_ << ", topics=" << topics_ << "}"; - } - -private: - const STRING group_id_; - const NULLABLE_ARRAY topics_; -}; - -// clang-format off -class OffsetFetchPartitionV0ArrayBuffer : public ArrayBuffer {}; -class OffsetFetchTopicV0Buffer : public CompositeBuffer {}; -class OffsetFetchTopicV0ArrayBuffer : public ArrayBuffer {}; - -class OffsetFetchRequestV0Buffer : public CompositeBuffer {}; - -DEFINE_REQUEST_PARSER(OffsetFetchRequest, V0); -// clang-format on - -// === API VERSIONS (18) ======================================================= - -class ApiVersionsRequest : public Request { -public: - // v0 .. v1 - ApiVersionsRequest() : Request{RequestType::ApiVersions} {}; - - bool operator==(const ApiVersionsRequest& rhs) const { - return request_header_ == rhs.request_header_; - }; - -protected: - size_t encodeDetails(Buffer::Instance&, EncodingContext&) const override { return 0; } - - std::ostream& printDetails(std::ostream& os) const override { return os << "{}"; } -}; - -// clang-format off -class ApiVersionsRequestV0Buffer : public NullBuffer {}; - -DEFINE_REQUEST_PARSER(ApiVersionsRequest, V0); // clang-format on // === UNKNOWN REQUEST ========================================================= diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index afc748cb7bbe0..a17ea97ba25b7 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -483,95 +483,6 @@ class CompositeBuffer : public Deserializer { T4 buffer4_; }; -template -class CompositeBuffer : public Deserializer { -public: - CompositeBuffer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += buffer1_.feed(buffer, remaining); - consumed += buffer2_.feed(buffer, remaining); - consumed += buffer3_.feed(buffer, remaining); - consumed += buffer4_.feed(buffer, remaining); - consumed += buffer5_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return buffer5_.ready(); } - RT get() const { - return {buffer1_.get(), buffer2_.get(), buffer3_.get(), buffer4_.get(), buffer5_.get()}; - } - -protected: - T1 buffer1_; - T2 buffer2_; - T3 buffer3_; - T4 buffer4_; - T5 buffer5_; -}; - -template -class CompositeBuffer : public Deserializer { -public: - CompositeBuffer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += buffer1_.feed(buffer, remaining); - consumed += buffer2_.feed(buffer, remaining); - consumed += buffer3_.feed(buffer, remaining); - consumed += buffer4_.feed(buffer, remaining); - consumed += buffer5_.feed(buffer, remaining); - consumed += buffer6_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return buffer6_.ready(); } - RT get() const { - return {buffer1_.get(), buffer2_.get(), buffer3_.get(), - 
buffer4_.get(), buffer5_.get(), buffer6_.get()};
-  }
-
-protected:
-  T1 buffer1_;
-  T2 buffer2_;
-  T3 buffer3_;
-  T4 buffer4_;
-  T5 buffer5_;
-  T6 buffer6_;
-};
-
-template <typename RT, typename T1, typename T2, typename T3, typename T4, typename T5,
-          typename T6, typename T7, typename T8>
-class CompositeBuffer<RT, T1, T2, T3, T4, T5, T6, T7, T8> : public Deserializer<RT> {
-public:
-  CompositeBuffer(){};
-  size_t feed(const char*& buffer, uint64_t& remaining) {
-    size_t consumed = 0;
-    consumed += buffer1_.feed(buffer, remaining);
-    consumed += buffer2_.feed(buffer, remaining);
-    consumed += buffer3_.feed(buffer, remaining);
-    consumed += buffer4_.feed(buffer, remaining);
-    consumed += buffer5_.feed(buffer, remaining);
-    consumed += buffer6_.feed(buffer, remaining);
-    consumed += buffer7_.feed(buffer, remaining);
-    consumed += buffer8_.feed(buffer, remaining);
-    return consumed;
-  }
-  bool ready() const { return buffer8_.ready(); }
-  RT get() const {
-    return {buffer1_.get(), buffer2_.get(), buffer3_.get(), buffer4_.get(),
-            buffer5_.get(), buffer6_.get(), buffer7_.get(), buffer8_.get()};
-  }
-
-protected:
-  T1 buffer1_;
-  T2 buffer2_;
-  T3 buffer3_;
-  T4 buffer4_;
-  T5 buffer5_;
-  T6 buffer6_;
-  T7 buffer7_;
-  T8 buffer8_;
-};
-
 // === ARRAY BUFFER ============================================================
 
 /**
diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc
index ce481ff71c38f..056d116b7715b 100644
--- a/test/extensions/filters/network/kafka/kafka_request_test.cc
+++ b/test/extensions/filters/network/kafka/kafka_request_test.cc
@@ -29,7 +29,7 @@ TEST(RequestParserResolver, ShouldReturnSentinelIfRequestTypeIsNotRegistered) {
 TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) {
   // given
   GeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr {
-    return std::make_shared<ApiVersionsRequestV0Parser>(arg);
+    return std::make_shared<OffsetCommitRequestV0Parser>(arg);
   };
   RequestParserResolver testee{{{0, {0, 1}, generator}}};
   RequestContextSharedPtr context{new RequestContext{}};
@@ -46,7 +46,7 @@ TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered)
 TEST(RequestParserResolver, ShouldInvokeGeneratorFunctionOnMatch) {
   // given
   GeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr {
-    return std::make_shared<ApiVersionsRequestV0Parser>(arg);
+    return std::make_shared<OffsetCommitRequestV0Parser>(arg);
   };
   RequestParserResolver testee{{{0, {0, 1, 2, 3}, generator}}};
   RequestContextSharedPtr context{new RequestContext{}};
@@ -56,7 +56,7 @@ TEST(RequestParserResolver, ShouldInvokeGeneratorFunctionOnMatch) {
 
   // then
   ASSERT_NE(result, nullptr);
-  ASSERT_NE(std::dynamic_pointer_cast<ApiVersionsRequestV0Parser>(result), nullptr);
+  ASSERT_NE(std::dynamic_pointer_cast<OffsetCommitRequestV0Parser>(result), nullptr);
 }
 
 class BufferBasedTest : public testing::Test {
@@ -104,7 +104,7 @@ class MockRequestParserResolver : public RequestParserResolver {
 TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) {
   // given
   const MockRequestParserResolver parser_resolver;
-  const ParserSharedPtr parser{new ApiVersionsRequestV0Parser{nullptr}};
+  const ParserSharedPtr parser{new OffsetCommitRequestV0Parser{nullptr}};
   EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser));
 
   const INT32 request_len = 1000;
diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc
index 3e834ca76751c..ab1a1798f3358 100644
--- a/test/extensions/filters/network/kafka/request_codec_test.cc
+++ b/test/extensions/filters/network/kafka/request_codec_test.cc
@@ -39,379 +39,6 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria
return std::dynamic_pointer_cast(receivedMessage); }; -// === PRODUCE (0) ============================================================= - -TEST_F(RequestDecoderTest, shouldParseProduceRequestV0toV2) { - // given - NULLABLE_ARRAY topics{ - {{"t1", {{{0, NULLABLE_BYTES(100)}, {1, NULLABLE_BYTES(200)}}}}, - {"t2", {{{0, NULLABLE_BYTES(300)}}}}}}; - FatProduceRequest request{10, 20, topics}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseProduceRequestV3) { - // given - NULLABLE_ARRAY topics{ - {{"t1", {{{0, NULLABLE_BYTES(100)}, {1, NULLABLE_BYTES(200)}}}}, - {"t2", {{{0, NULLABLE_BYTES(300)}}}}}}; - // transaction_id in V3 - FatProduceRequest request{"txid", 10, 20, topics}; - request.apiVersion() = 3; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === FETCH (1) =============================================================== - -TEST_F(RequestDecoderTest, shouldParseFetchRequestV0toV2) { - // given - FetchRequest request{1, - 1000, - 10, - {{ - {"topic1", {{{10, 20, 2000}}}}, - {"topic1", {{{11, 21, 2001}, {12, 22, 2002}}}}, - {"topic1", {{{13, 23, 2003}}}}, - }}}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseFetchRequestV3) { - // given - FetchRequest request{1, - 1000, - 10, - 20, // max_bytes in V3 - {{ - {"topic1", {{{10, 20, 2000}}}}, - {"topic1", {{{11, 21, 2001}, {12, 22, 2002}}}}, - {"topic1", {{{13, 23, 2003}}}}, - }}}; - request.apiVersion() = 3; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseFetchRequestV4) { - // given - FetchRequest request{1, - 1000, - 10, - 20, - 2, // isolation level in V4 - {{ - {"topic1", {{{10, 20, 2000}}}}, - {"topic1", {{{11, 21, 2001}, {12, 22, 2002}}}}, - {"topic1", {{{13, 23, 2003}}}}, - }}}; - request.apiVersion() = 4; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseFetchRequestV5) { - // given - FetchRequest request{1, - 1000, - 10, - 20, - 2, - {{ - // log_start_offset_ in partition data in V5 - {"topic1", {{{10, 20, 1000, 2000}}}}, - {"topic1", {{{11, 21, 1001, 2001}, {12, 22, 1002, 2002}}}}, - {"topic1", {{{13, 23, 1003, 2003}}}}, - }}}; - request.apiVersion() = 5; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === LIST OFFSETS (2) ======================================================== - -TEST_F(RequestDecoderTest, shouldParseListOffsetsRequestV0) { - // given - ListOffsetsRequest request{10, - {{ - // partition contains max_num_offsets in v0 
only - {"topic1", {{{1, 1000, 10}, {2, 2000, 20}}}}, - {"topic2", {{{3, 3000, 30}}}}, - }}}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseListOffsetsRequestV1) { - // given - ListOffsetsRequest request{10, - {{ - // max_num_offsets removed in v1 - {"topic1", {{{1, 1000}, {2, 2000}}}}, - {"topic2", {{{3, 3000}}}}, - }}}; - request.apiVersion() = 1; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseListOffsetsRequestV2) { - // given - ListOffsetsRequest request{10, - 2, // isolation level in v2 - {{ - {"topic1", {{{1, 1000}, {2, 2000}}}}, - {"topic2", {{{3, 3000}}}}, - }}}; - request.apiVersion() = 2; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === METADATA (3) ============================================================ - -TEST_F(RequestDecoderTest, shouldParseMetadataRequestV0toV3) { - // given - MetadataRequest request{{{"t1", "t2", "t3"}}}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseMetadataRequestV4) { - // given - MetadataRequest request{{{"t1", "t2", "t3"}}, true}; - request.apiVersion() = 4; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === LEADER-AND-ISR (4) ====================================================== - -TEST_F(RequestDecoderTest, shouldParseLeaderAndIsrRequestV0) { - // given - MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; - MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; - NULLABLE_ARRAY partition_states{{ps1, ps2}}; - NULLABLE_ARRAY live_leaders{ - {{1, "host1", 9092}, {2, "host2", 9093}, {3, "host3", 9094}}}; - LeaderAndIsrRequest request{20, 1000, partition_states, live_leaders}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === STOP REPLICA (5) ======================================================== - -TEST_F(RequestDecoderTest, shouldParseStopReplicaRequestV0) { - // given - NULLABLE_ARRAY partitions{{{"t1", 0}, {"t2", 3}}}; - StopReplicaRequest request{10, 1000, true, partitions}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === UPDATE METADATA (6) ===================================================== - -TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV0) { - // given - 
MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; - MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; - NULLABLE_ARRAY partition_states{{ps1, ps2}}; - NULLABLE_ARRAY live_brokers{ - {{1, "host1", 9092}, {2, "host2", 9093}, {3, "host3", 9094}}}; - UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV1) { - // given - MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; - MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; - NULLABLE_ARRAY partition_states{{ps1, ps2}}; - NULLABLE_ARRAY live_brokers{ - {{1, {{{9092, "h1", 0}, {9093, "h1", 1}}}}, // endpoints added in v1, host & port removed - {2, {{{9092, "h2", 2}, {9093, "h2", 3}}}}}}; - UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; - request.apiVersion() = 1; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV2) { - // given - MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; - MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; - NULLABLE_ARRAY partition_states{{ps1, ps2}}; - NULLABLE_ARRAY live_brokers{ - {{1, {{{9092, "h1", 0}, {9093, "h1", 1}}}, "rack1"}, // rack added in v2 - {2, {{{9092, "h2", 2}, {9093, "h2", 3}}}, absl::nullopt}}}; - UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; - request.apiVersion() = 2; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseUpdateMetadataRequestV3) { - // given - MetadataPartitionState ps1{"t1", 0, 1000, 1, 2000, {{0, 1}}, 3000, {{0, 1, 2, 3, 4}}}; - MetadataPartitionState ps2{"t2", 1, 4000, 2, 5000, {{6, 7}}, 6000, {{6, 7, 8, 9, 10}}}; - NULLABLE_ARRAY partition_states{{ps1, ps2}}; - NULLABLE_ARRAY live_brokers{ - {{1, - {{{9092, "h1", "plain", 0}, {9093, "h1", "ssl", 1}}}, - "rack1"}, // listener_name added in v2 - {2, {{{9092, "h2", "name3", 2}, {9093, "h2", "name4", 3}}}, absl::nullopt}}}; - UpdateMetadataRequest request{20, 1000, partition_states, live_brokers}; - request.apiVersion() = 3; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === CONTROLLED SHUTDOWN (7) ================================================= - -TEST_F(RequestDecoderTest, shouldParseControlledShutdownRequestV1) { - // given - ControlledShutdownRequest request{1}; - request.apiVersion() = 1; - request.correlationId() = 42; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - // === OFFSET COMMIT (8) 
======================================================= TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV0) { @@ -451,66 +78,13 @@ TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) { ASSERT_EQ(*received, request); } -TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV2toV3) { - // given - NULLABLE_ARRAY topics{ - {{"topic1", {{{0, 10, "m1"}, {2, 20, "m2"}}}}, {"topic2", {{{3, 30, "m3"}}}}}}; - OffsetCommitRequest request{"group_id", 1234, "member", - 2345, // retention_time - topics}; - request.apiVersion() = 2; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === OFFSET FETCH (9) ======================================================== - -TEST_F(RequestDecoderTest, shouldParseOffsetFetchRequestV0toV3) { - // given - OffsetFetchRequest request{"group_id", {{{"topic1", {{0, 1, 2}}}, {"topic2", {{3, 4}}}}}}; - - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -// === API VERSIONS (18) ======================================================= - -TEST_F(RequestDecoderTest, shouldParseApiVersionsRequestV0toV1) { - // given - ApiVersionsRequest request{}; - request.apiVersion() = 0; - request.correlationId() = 42; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - // === UNKNOWN REQUEST ========================================================= TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given RequestEncoder serializer{buffer_}; - ApiVersionsRequest request{}; + NULLABLE_ARRAY topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; + OffsetCommitRequest request{"group_id", topics}; request.apiVersion() = 1; request.correlationId() = 42; request.clientId() = "client-id"; From f341df7d438b37d35474406bfa76c5cd5f4bdca3 Mon Sep 17 00:00:00 2001 From: "adam.kotwasinski" Date: Wed, 14 Nov 2018 00:21:05 +0000 Subject: [PATCH 04/29] Remove bytes buffers (for review - unused in these requests) Signed-off-by: Adam Kotwasinski --- .../filters/network/kafka/serialization.h | 136 ------------------ .../network/kafka/serialization_test.cc | 121 ---------------- 2 files changed, 257 deletions(-) diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index a17ea97ba25b7..cf1649c2749b3 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -264,142 +264,6 @@ class NullableStringBuffer : public Deserializer { bool ready_{false}; }; -// === BYTES BUFFERS =========================================================== - -/** - * Represents a raw sequence of bytes or null. - * For non-null values, first the length N is given as an INT32. Then N bytes follow. - * A null value is encoded with length of -1 and there are no following bytes. 
- */ - -/** - * This buffer ignores the data fed, the only result is the number of bytes ignored - */ -class NullableBytesIgnoringBuffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - const size_t length_consumed = length_buf_.feed(buffer, remaining); - if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer - return length_consumed; - } - - if (!length_consumed_) { - required_max_ = length_buf_.get(); - required_ = length_buf_.get(); - - if (required_ == NULL_BYTES_LENGTH) { - ready_ = true; - } - if (required_ < NULL_BYTES_LENGTH) { - throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); - } - - length_consumed_ = true; - } - - if (ready_) { - return length_consumed; - } - - const size_t data_consumed = std::min(required_, remaining); - required_ -= data_consumed; - - buffer += data_consumed; - remaining -= data_consumed; - - if (required_ == 0) { - ready_ = true; - } - - return length_consumed + data_consumed; - } - - bool ready() const { return ready_; } - - /** - * Returns length of ignored array, or -1 if that was null - */ - INT32 get() const { return required_max_; } - -private: - constexpr static INT32 NULL_BYTES_LENGTH{-1}; - - Int32Buffer length_buf_; - bool length_consumed_{false}; - INT32 required_max_; - INT32 required_; - bool ready_{false}; -}; - -/** - * This buffer captures the data fed - */ -class NullableBytesCapturingBuffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - const size_t length_consumed = length_buf_.feed(buffer, remaining); - if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer - return length_consumed; - } - - if (!length_consumed_) { - required_ = length_buf_.get(); - - if (required_ >= 0) { - data_buf_ = std::vector(required_); - } - if (required_ == NULL_BYTES_LENGTH) { - ready_ = true; - } - if (required_ < NULL_BYTES_LENGTH) { - throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); - } - - length_consumed_ = true; - } - - if (ready_) { - return length_consumed; - } - - const size_t data_consumed = std::min(required_, remaining); - const size_t written = data_buf_.size() - required_; - memcpy(data_buf_.data() + written, buffer, data_consumed); - required_ -= data_consumed; - - buffer += data_consumed; - remaining -= data_consumed; - - if (required_ == 0) { - ready_ = true; - } - - return length_consumed + data_consumed; - } - - bool ready() const { return ready_; } - - NULLABLE_BYTES get() const { - if (NULL_BYTES_LENGTH == required_) { - return absl::nullopt; - } else { - return {data_buf_}; - } - } - -private: - constexpr static INT32 NULL_BYTES_LENGTH{-1}; - - Int32Buffer length_buf_; - bool length_consumed_{false}; - INT32 required_; - - std::vector data_buf_; - bool ready_{false}; -}; - // === COMPOSITE BUFFER ======================================================== /** diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 27c0e627b5be4..e26cdc94946a2 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -29,7 +29,6 @@ TEST_EmptyBufferShouldNotBeReady(Int64Buffer); TEST_EmptyBufferShouldNotBeReady(BoolBuffer); TEST_EmptyBufferShouldNotBeReady(StringBuffer); TEST_EmptyBufferShouldNotBeReady(NullableStringBuffer); 
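For orientation, the macro applied above presumably expands to a test over a default-constructed deserializer; a hand-written equivalent (the name is illustrative, and the real macro body defined earlier in this test file may differ in detail) would be:

TEST(Int32Buffer, EmptyBufferShouldNotBeReadySketch) {
  // a deserializer that has been fed no bytes must not claim readiness
  const Int32Buffer testee{};
  ASSERT_EQ(testee.ready(), false);
}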
-TEST_EmptyBufferShouldNotBeReady(NullableBytesIgnoringBuffer); TEST(CompositeBuffer, EmptyBufferShouldNotBeReady) { // given const CompositeBuffer testee{}; @@ -226,126 +225,6 @@ TEST(NullableStringBuffer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -// === NULLABLE BYTES IGNORING BUFFER ========================================== - -TEST(NullableBytesIgnoringBuffer, ShouldDeserialize) { - // given - NullableBytesIgnoringBuffer testee; - - const INT32 bytes_to_ignore = 100; - NULLABLE_BYTES bytes = {std::vector(bytes_to_ignore)}; - - Buffer::OwnedImpl buffer; - size_t written = encoder.encode(bytes, buffer); - - uint64_t remaining = written * 10; - const uint64_t orig_remaining = remaining; - - const char* data = getRawData(buffer); - const char* orig_data = data; - - // when - const size_t consumed = testee.feed(data, remaining); - - // then - ASSERT_EQ(consumed, written); - ASSERT_EQ(testee.ready(), true); - ASSERT_EQ(testee.get(), bytes_to_ignore); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); - - // when - 2 - const size_t consumed2 = testee.feed(data, remaining); - - // then - 2 (nothing changes) - ASSERT_EQ(consumed2, 0); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); -} - -TEST(NullableBytesIgnoringBuffer, ShouldDeserializeNullBytes) { - // given - NullableBytesIgnoringBuffer testee; - - const INT32 bytes_length = -1; - - Buffer::OwnedImpl buffer; - const size_t written = encoder.encode(bytes_length, buffer); - - uint64_t remaining = written * 10; - const uint64_t orig_remaining = remaining; - - const char* data = getRawData(buffer); - const char* orig_data = data; - - // when - const size_t consumed = testee.feed(data, remaining); - - // then - ASSERT_EQ(consumed, written); - ASSERT_EQ(testee.ready(), true); - ASSERT_EQ(testee.get(), bytes_length); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); - - // when - 2 - const size_t consumed2 = testee.feed(data, remaining); - - // then - 2 (nothing changes) - ASSERT_EQ(consumed2, 0); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); -} - -TEST(NullableBytesIgnoringBuffer, ShouldThrowOnInvalidLength) { - // given - NullableBytesIgnoringBuffer testee; - Buffer::OwnedImpl buffer; - - const INT32 bytes_length = -2; // -1 is OK for NULLABLE_BYTES - encoder.encode(bytes_length, buffer); - - uint64_t remaining = 1024; - const char* data = getRawData(buffer); - - // when - // then - EXPECT_THROW(testee.feed(data, remaining), EnvoyException); -} - -// === NULLABLE BYTES CAPTURING BUFFER ========================================= - -TEST(NullableBytesCapturingBuffer, ShouldDeserialize) { - const NULLABLE_BYTES value{{'a', 'b', 'c', 'd'}}; - serializeThenDeserializeAndCheckEquality(value); -} - -TEST(NullableBytesCapturingBuffer, ShouldDeserializeEmptyBytes) { - const NULLABLE_BYTES value{{}}; - serializeThenDeserializeAndCheckEquality(value); -} - -TEST(NullableBytesCapturingBuffer, ShouldDeserializeNullBytes) { - const NULLABLE_BYTES value = absl::nullopt; - serializeThenDeserializeAndCheckEquality(value); -} - -TEST(NullableBytesCapturingBuffer, ShouldThrowOnInvalidLength) { - // given - NullableBytesCapturingBuffer testee; - Buffer::OwnedImpl buffer; - - const INT32 bytes_length = -2; // -1 is OK for NULLABLE_BYTES - encoder.encode(bytes_length, buffer); - - uint64_t remaining = 1024; - const char* data = getRawData(buffer); 
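  // At this point the buffer holds the big-endian INT32 encoding of -2, i.e. the bytes
  // ff ff ff fe; any length below -1 is invalid for NULLABLE_BYTES, so the feed below is
  // expected to throw instead of trying to set up a negative-sized store.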
- - // when - // then - EXPECT_THROW(testee.feed(data, remaining), EnvoyException); -} - // === ARRAY BUFFER ============================================================ TEST(ArrayBuffer, ShouldConsumeCorrectAmountOfData) { From 12ece4f9bf6452ea5d7e34a0af00bb6e8dd231d6 Mon Sep 17 00:00:00 2001 From: "adam.kotwasinski" Date: Wed, 14 Nov 2018 14:56:41 +0000 Subject: [PATCH 05/29] Apply review fixes - remove unnecessary request type constants - remove garbage comments - move operator<< helper to separate header - move requests (currently only offset_fetch) to separate header - remove unnecessary typedefs - add missing documentation - improve GeneratorMap Signed-off-by: Adam Kotwasinski --- source/extensions/extensions_build_config.bzl | 1 - source/extensions/filters/network/kafka/BUILD | 10 +- .../filters/network/kafka/debug_helpers.h | 39 +++ .../filters/network/kafka/kafka_protocol.h | 108 +----- .../filters/network/kafka/kafka_request.cc | 61 ++-- .../filters/network/kafka/kafka_request.h | 314 +++++++----------- .../filters/network/kafka/kafka_types.h | 42 ++- .../filters/network/kafka/message.h | 2 +- .../network/kafka/messages/offset_commit.h | 154 +++++++++ .../extensions/filters/network/kafka/parser.h | 38 ++- .../filters/network/kafka/request_codec.cc | 10 +- .../filters/network/kafka/request_codec.h | 7 +- .../filters/network/kafka/serialization.h | 254 ++++++++------ .../network/kafka/kafka_request_test.cc | 27 +- .../network/kafka/request_codec_test.cc | 11 +- .../network/kafka/serialization_test.cc | 58 ++-- 16 files changed, 623 insertions(+), 513 deletions(-) create mode 100644 source/extensions/filters/network/kafka/debug_helpers.h create mode 100644 source/extensions/filters/network/kafka/messages/offset_commit.h diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index c168a3cd0c7f1..5fb01c97c821c 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -65,7 +65,6 @@ EXTENSIONS = { "envoy.filters.network.echo": "//source/extensions/filters/network/echo:config", "envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config", "envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config", - "envoy.filters.network.kafka": "//source/extensions/filters/network/kafka:config", "envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config", "envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config", "envoy.filters.network.rbac": "//source/extensions/filters/network/rbac:config", diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index 218e7fc5076a2..8e2c5420f6cc7 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -11,10 +11,6 @@ load( envoy_package() -envoy_cc_library( - name = "config", -) - envoy_cc_library( name = "kafka_request_codec_lib", srcs = ["request_codec.cc"], @@ -31,7 +27,11 @@ envoy_cc_library( envoy_cc_library( name = "kafka_request_lib", srcs = ["kafka_request.cc"], - hdrs = ["kafka_request.h"], + hdrs = [ + "debug_helpers.h", + "kafka_request.h", + "messages/offset_commit.h", + ], deps = [ ":parser_lib", ":serialization_lib", diff --git a/source/extensions/filters/network/kafka/debug_helpers.h b/source/extensions/filters/network/kafka/debug_helpers.h new file mode 100644 index 
0000000000000..17224c3baa54a --- /dev/null +++ b/source/extensions/filters/network/kafka/debug_helpers.h @@ -0,0 +1,39 @@ +#pragma once + +#include <vector> + +#include "absl/types/optional.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +// functions present in this header are used by request / response objects to print their fields +// nicely + +template <typename T> std::ostream& operator<<(std::ostream& os, const std::vector<T>& arg) { + os << "["; + for (auto iter = arg.begin(); iter != arg.end(); iter++) { + if (iter != arg.begin()) { + os << ", "; + } + os << *iter; + } + os << "]"; + return os; +} + +template <typename T> std::ostream& operator<<(std::ostream& os, const absl::optional<T>& arg) { + if (arg.has_value()) { + os << *arg; + } else { + os << ""; + } + return os; +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_protocol.h b/source/extensions/filters/network/kafka/kafka_protocol.h index f6452abd3e859..ca785f14c1500 100644 --- a/source/extensions/filters/network/kafka/kafka_protocol.h +++ b/source/extensions/filters/network/kafka/kafka_protocol.h @@ -11,110 +11,12 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// from http://kafka.apache.org/protocol.html#protocol_api_keys -enum RequestType : INT16 { - Produce = 0, - Fetch = 1, - ListOffsets = 2, - Metadata = 3, - LeaderAndIsr = 4, - StopReplica = 5, - UpdateMetadata = 6, - ControlledShutdown = 7, +/** + * Kafka request type identifier + * @see http://kafka.apache.org/protocol.html#protocol_api_keys + */ +enum RequestType : int16_t { OffsetCommit = 8, - OffsetFetch = 9, - FindCoordinator = 10, - JoinGroup = 11, - Heartbeat = 12, - LeaveGroup = 13, - SyncGroup = 14, - DescribeGroups = 15, - ListGroups = 16, - SaslHandshake = 17, - ApiVersions = 18, - CreateTopics = 19, - DeleteTopics = 20, - DeleteRecords = 21, - InitProducerId = 22, - OffsetForLeaderEpoch = 23, - AddPartitionsToTxn = 24, - AddOffsetsToTxn = 25, - EndTxn = 26, - WriteTxnMarkers = 27, - TxnOffsetCommit = 28, - DescribeAcls = 29, - CreateAcls = 30, - DeleteAcls = 31, - DescribeConfigs = 32, - AlterConfigs = 33, - AlterReplicaLogDirs = 34, - DescribeLogDirs = 35, - SaslAuthenticate = 36, - CreatePartitions = 37, - CreateDelegationToken = 38, - RenewDelegationToken = 39, - ExpireDelegationToken = 40, - DescribeDelegationToken = 41, - DeleteGroups = 42 -}; - -struct RequestSpec { - const INT16 api_key_; - const std::string name_; -}; - -struct KafkaRequest { - - // clang-format off - static const std::vector<RequestSpec>& requests() { - CONSTRUCT_ON_FIRST_USE( - std::vector<RequestSpec>, - {RequestType::Produce, "Produce"}, - {RequestType::Fetch, "Fetch"}, - {RequestType::ListOffsets, "ListOffsets"}, - {RequestType::Metadata, "Metadata"}, - {RequestType::LeaderAndIsr, "LeaderAndIsr"}, - {RequestType::StopReplica, "StopReplica"}, - {RequestType::UpdateMetadata, "UpdateMetadata"}, - {RequestType::ControlledShutdown, "ControlledShutdown"}, - {RequestType::OffsetCommit, "OffsetCommit"}, - {RequestType::OffsetFetch, "OffsetFetch"}, - {RequestType::FindCoordinator, "FindCoordinator"}, - {RequestType::JoinGroup, "JoinGroup"}, - {RequestType::Heartbeat, "Heartbeat"}, - {RequestType::LeaveGroup, "LeaveGroup"}, - {RequestType::SyncGroup, "SyncGroup"}, - {RequestType::DescribeGroups, "DescribeGroups"}, - {RequestType::ListGroups, "ListGroups"}, - {RequestType::SaslHandshake, "SaslHandshake"}, - {RequestType::ApiVersions, "ApiVersions"}, -
{RequestType::CreateTopics, "CreateTopics"}, - {RequestType::DeleteTopics, "DeleteTopics"}, - {RequestType::DeleteRecords, "DeleteRecords"}, - {RequestType::InitProducerId, "InitProducerId"}, - {RequestType::OffsetForLeaderEpoch, "OffsetForLeaderEpoch"}, - {RequestType::AddPartitionsToTxn, "AddPartitionsToTxn"}, - {RequestType::AddOffsetsToTxn, "AddOffsetsToTxn"}, - {RequestType::EndTxn, "EndTxn"}, - {RequestType::WriteTxnMarkers, "WriteTxnMarkers"}, - {RequestType::TxnOffsetCommit, "TxnOffsetCommit"}, - {RequestType::DescribeAcls, "DescribeAcls"}, - {RequestType::CreateAcls, "CreateAcls"}, - {RequestType::DeleteAcls, "DeleteAcls"}, - {RequestType::DescribeConfigs, "DescribeConfigs"}, - {RequestType::AlterConfigs, "AlterConfigs"}, - {RequestType::AlterReplicaLogDirs, "AlterReplicaLogDirs"}, - {RequestType::DescribeLogDirs, "DescribeLogDirs"}, - {RequestType::SaslAuthenticate, "SaslAuthenticate"}, - {RequestType::CreatePartitions, "CreatePartitions"}, - {RequestType::CreateDelegationToken, "CreateDelegationToken"}, - {RequestType::RenewDelegationToken, "RenewDelegationToken"}, - {RequestType::ExpireDelegationToken, "ExpireDelegationToken"}, - {RequestType::DescribeDelegationToken, "DescribeDelegationToken"}, - {RequestType::DeleteGroups, "DeleteGroups"} - ); - } - // clang-format on }; } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request.cc index d308cf80514fb..8d1ac4381e60b 100644 --- a/source/extensions/filters/network/kafka/kafka_request.cc +++ b/source/extensions/filters/network/kafka/kafka_request.cc @@ -1,6 +1,7 @@ #include "extensions/filters/network/kafka/kafka_request.h" #include "extensions/filters/network/kafka/kafka_protocol.h" +#include "extensions/filters/network/kafka/messages/offset_commit.h" #include "extensions/filters/network/kafka/parser.h" namespace Envoy { @@ -8,24 +9,27 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// === REQUEST PARSER MAPPING (REQUEST TYPE => PARSER) ========================= - -GeneratorMap computeGeneratorMap(std::vector specs) { - GeneratorMap result; +// helper function that generates a map from specs looking like { api_key, api_versions... } +GeneratorMap computeGeneratorMap(const GeneratorMap& original, + const std::vector specs) { + GeneratorMap result{original}; for (auto& spec : specs) { - auto generators = result[spec.api_key_]; - if (!generators) { - generators = std::make_shared>(); - result[spec.api_key_] = generators; - } - for (INT16 api_version : spec.api_versions_) { - (*generators)[api_version] = spec.generator_; + auto& generators = result[spec.api_key_]; + for (int16_t api_version : spec.api_versions_) { + generators[api_version] = spec.generator_; } } return result; } +RequestParserResolver::RequestParserResolver(const std::vector arg) + : generators_{computeGeneratorMap({}, arg)} {}; + +RequestParserResolver::RequestParserResolver(const RequestParserResolver& original, + const std::vector arg) + : generators_{computeGeneratorMap(original.generators_, arg)} {}; + #define PARSER_SPEC(REQUEST_NAME, PARSER_VERSION, ...) 
\ ParserSpec { \ RequestType::REQUEST_NAME, {__VA_ARGS__}, [](RequestContextSharedPtr arg) -> ParserSharedPtr { \ @@ -34,32 +38,37 @@ GeneratorMap computeGeneratorMap(std::vector<ParserSpec> specs) { } const RequestParserResolver RequestParserResolver::KAFKA_0_11{{ - PARSER_SPEC(OffsetCommit, V0, 0), - PARSER_SPEC(OffsetCommit, V1, 1), + PARSER_SPEC(OffsetCommit, V0, 0), PARSER_SPEC(OffsetCommit, V1, 1), // XXX(adam.kotwasinski) missing request types here }}; -ParserSharedPtr RequestParserResolver::createParser(INT16 api_key, INT16 api_version, +const RequestParserResolver RequestParserResolver::KAFKA_1_0{ + RequestParserResolver::KAFKA_0_11, + { + // XXX(adam.kotwasinski) missing request types & versions here + }}; + +ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version, RequestContextSharedPtr context) const { + + // api_key const auto api_versions_ptr = generators_.find(api_key); - // unknown api_key if (generators_.end() == api_versions_ptr) { - return std::make_shared<SentinelConsumer>(context); + return std::make_shared<SentinelParser>(context); } - const auto api_versions = api_versions_ptr->second; + const std::unordered_map<int16_t, GeneratorFunction>& api_versions = api_versions_ptr->second; - // unknown api_version - const auto generator = api_versions->find(api_version); - if (api_versions->end() == generator) { - return std::make_shared<SentinelConsumer>(context); + // api_version + const auto generator_ptr = api_versions.find(api_version); + if (api_versions.end() == generator_ptr) { + return std::make_shared<SentinelParser>(context); } // found matching parser generator, create parser - return generator->second(context); + const GeneratorFunction generator = generator_ptr->second; + return generator(context); } -// === HEADER PARSERS ========================================================== - ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining) { buffer_.feed(buffer, remaining); if (buffer_.ready()) { @@ -85,9 +94,7 @@ ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remainin } } -// === UNKNOWN REQUEST ========================================================= - -ParseResponse SentinelConsumer::parse(const char*& buffer, uint64_t& remaining) { +ParseResponse SentinelParser::parse(const char*& buffer, uint64_t& remaining) { const size_t min = std::min(context_->remaining_request_size_, remaining); buffer += min; remaining -= min; diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 6727dd2a84204..0f8732717f652 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -6,6 +6,7 @@ #include "common/common/assert.h" +#include "extensions/filters/network/kafka/debug_helpers.h" #include "extensions/filters/network/kafka/kafka_protocol.h" #include "extensions/filters/network/kafka/parser.h" #include "extensions/filters/network/kafka/serialization.h" @@ -15,36 +16,15 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// === VECTOR ================================================================== - -template <typename T> std::ostream& operator<<(std::ostream& os, const std::vector<T>& arg) { - os << "["; - for (auto iter = arg.begin(); iter != arg.end(); iter++) { - if (iter != arg.begin()) { - os << ", "; - } - os << *iter; - } - os << "]"; - return os; -} - -template <typename T> std::ostream& operator<<(std::ostream& os, const absl::optional<T>& arg) { - if (arg.has_value()) { - os << *arg; - } else { - os << ""; - } - return os; -} - -// === REQUEST
HEADER ========================================================== - +/** + * Represents fields that are present in every Kafka request message + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ struct RequestHeader { - INT16 api_key_; - INT16 api_version_; - INT32 correlation_id_; - NULLABLE_STRING client_id_; + int16_t api_key_; + int16_t api_version_; + int32_t correlation_id_; + NullableString client_id_; bool operator==(const RequestHeader& rhs) const { return api_key_ == rhs.api_key_ && api_version_ == rhs.api_version_ && @@ -58,8 +38,11 @@ struct RequestHeader { }; }; +/** + * Context that is shared between parsers that are handling the same single message + */ struct RequestContext { - INT32 remaining_request_size_{0}; + int32_t remaining_request_size_{0}; RequestHeader request_header_{}; friend std::ostream& operator<<(std::ostream& os, const RequestContext& arg) { @@ -70,52 +53,74 @@ typedef std::shared_ptr<RequestContext> RequestContextSharedPtr; -// === REQUEST PARSER MAPPING (REQUEST TYPE => PARSER) ========================= - -// a function generating a parser with given context +/** + * Function generating a parser with given context + */ typedef std::function<ParserSharedPtr(RequestContextSharedPtr)> GeneratorFunction; -// two-level map: api_key -> api_version -> generator function -typedef std::unordered_map<INT16, std::shared_ptr<std::unordered_map<INT16, GeneratorFunction>>> - GeneratorMap; +/** + * Structure responsible for mapping [api_key, api_version] -> GeneratorFunction + */ +typedef std::unordered_map<int16_t, std::unordered_map<int16_t, GeneratorFunction>> GeneratorMap; +/** + * Trivial structure specifying which generator function should be used for which api_key & + * api_version + */ struct ParserSpec { - const INT16 api_key_; - const std::vector<INT16> api_versions_; + const int16_t api_key_; + const std::vector<int16_t> api_versions_; const GeneratorFunction generator_; }; -// helper function that generates a map from specs looking like { api_key, api_versions... } -GeneratorMap computeGeneratorMap(std::vector<ParserSpec> arg); - /** - * Provides the parser that is responsible for consuming the request-specific data + * Configuration object + * Resolves the parser that will be responsible for consuming the request-specific data * In other words: provides (api_key, api_version) -> Parser function */ class RequestParserResolver { public: - RequestParserResolver(std::vector<ParserSpec> arg) : generators_{computeGeneratorMap(arg)} {}; + RequestParserResolver(const std::vector<ParserSpec> arg); + RequestParserResolver(const RequestParserResolver& original, const std::vector<ParserSpec> arg); virtual ~RequestParserResolver() = default; - virtual ParserSharedPtr createParser(INT16 api_key, INT16 api_version, + /** + * Creates a parser that is going to process data specific for given api_key & api_version + * @param api_key request type + * @param api_version request version + * @param context context to be used by parser + * @return parser that is capable of processing data for given request type & version + */ + virtual ParserSharedPtr createParser(int16_t api_key, int16_t api_version, RequestContextSharedPtr context) const; + /** + * Request versions handled by Kafka up to 0.11 + */ static const RequestParserResolver KAFKA_0_11; + /** + * Request versions handled by Kafka up to 1.0 + */ + static const RequestParserResolver KAFKA_1_0; + private: GeneratorMap generators_; }; -// === INITIAL PARSERS ========================================================= - /** - * Request start parser just consumes the length of request + * Request parser responsible for consuming request length and setting up context with this data + * @see http://kafka.apache.org/protocol.html#protocol_common */ class RequestStartParser : public Parser { public: RequestStartParser(const RequestParserResolver& parser_resolver) : parser_resolver_{parser_resolver}, context_{std::make_shared<RequestContext>()} {}; + /** + * Consumes an INT32 as the request length and updates the context with that value + * @return RequestHeaderParser instance to process request header + */ ParseResponse parse(const char*& buffer, uint64_t& remaining); const RequestContextSharedPtr contextForTest() const { return context_; } @@ -126,17 +131,27 @@ class RequestStartParser : public Parser { Int32Buffer buffer_; }; +/** + * Buffer that gets filled in with request header data + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ class RequestHeaderBuffer : public CompositeBuffer<RequestHeader, Int16Buffer, Int16Buffer, Int32Buffer, NullableStringBuffer> {}; /** - * Request header parser consumes request header + * Parser responsible for consuming the request header and updating the context with the resolved + * data. + * On a successful parse uses resolved data (api_key & api_version) to determine the next parser.
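+ * A rough usage sketch (illustrative only, not a call site in this patch): + *   RequestHeaderParser parser{RequestParserResolver::KAFKA_0_11, context}; + *   ParseResponse result = parser.parse(buffer, remaining); + *   // once enough header bytes have been fed, the response carries the + *   // request-specific parser (or SentinelParser for unknown api_key / api_version)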
+ * @see http://kafka.apache.org/protocol.html#protocol_messages */ class RequestHeaderParser : public Parser { public: RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context) : parser_resolver_{parser_resolver}, context_{context} {}; + /** + * Uses data provided to compute request header + * @return Parser instance responsible for processing rest of the message + */ ParseResponse parse(const char*& buffer, uint64_t& remaining); const RequestContextSharedPtr contextForTest() const { return context_; } @@ -147,12 +162,12 @@ class RequestHeaderParser : public Parser { RequestHeaderBuffer buffer_; }; -// === BUFFERED PARSER ========================================================= - /** * Buffered parser uses a single buffer to construct a response * This parser is responsible for consuming request-specific data (e.g. topic names) and always * returns a parsed message + * @param RT request class + * @param BT buffer type corresponding to request class */ template class BufferedParser : public Parser { public: @@ -161,7 +176,7 @@ template class BufferedParser : public Parser { protected: RequestContextSharedPtr context_; - BT buffer_; + BT buffer_; // underlying request-specific buffer }; template @@ -180,8 +195,11 @@ ParseResponse BufferedParser::parse(const char*& buffer, uint64_t& remai } } -// names of Buffers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names - +/** + * Macro defining RequestParser that uses the underlying Buffer + * Aware of versioning + * Names of Buffers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names + */ #define DEFINE_REQUEST_PARSER(REQUEST_TYPE, VERSION) \ class REQUEST_TYPE##VERSION##Parser \ : public BufferedParser { \ @@ -189,180 +207,68 @@ ParseResponse BufferedParser::parse(const char*& buffer, uint64_t& remai REQUEST_TYPE##VERSION##Parser(RequestContextSharedPtr ctx) : BufferedParser{ctx} {}; \ }; -// === ABSTRACT REQUEST ======================================================== - +/** + * Abstract Kafka request + * Contains data present in every request + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ class Request : public Message { public: /** * Request header fields need to be initialized by user in case of newly created requests */ - Request(INT16 api_key) : request_header_{api_key, 0, 0, ""} {}; + Request(int16_t api_key) : request_header_{api_key, 0, 0, ""} {}; Request(const RequestHeader& request_header) : request_header_{request_header} {}; RequestHeader& header() { return request_header_; } - INT16& apiVersion() { return request_header_.api_version_; } - INT16 apiVersion() const { return request_header_.api_version_; } + int16_t& apiVersion() { return request_header_.api_version_; } + int16_t apiVersion() const { return request_header_.api_version_; } - INT32& correlationId() { return request_header_.correlation_id_; } + int32_t& correlationId() { return request_header_.correlation_id_; } - NULLABLE_STRING& clientId() { return request_header_.client_id_; } + NullableString& clientId() { return request_header_.client_id_; } - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + /** + * Encodes given request into a buffer, with any extra configuration carried by the context + */ + size_t encode(Buffer::Instance& dst, EncodingContext& context) const { size_t written{0}; - written += encoder.encode(request_header_.api_key_, dst); - written += encoder.encode(request_header_.api_version_, dst); - written += 
encoder.encode(request_header_.correlation_id_, dst); - written += encoder.encode(request_header_.client_id_, dst); - written += encodeDetails(dst, encoder); + written += context.encode(request_header_.api_key_, dst); + written += context.encode(request_header_.api_version_, dst); + written += context.encode(request_header_.correlation_id_, dst); + written += context.encode(request_header_.client_id_, dst); + written += encodeDetails(dst, context); return written; } + /** + * Pretty-prints given request into a stream + */ std::ostream& print(std::ostream& os) const override final { os << request_header_ << " "; // not very pretty return printDetails(os); } protected: + /** + * Encodes request-specific data into a buffer + */ virtual size_t encodeDetails(Buffer::Instance&, EncodingContext&) const PURE; + /** + * Prints request-specific data into a stream + */ virtual std::ostream& printDetails(std::ostream&) const PURE; RequestHeader request_header_; }; -// === OFFSET COMMIT (8) ======================================================= - -struct OffsetCommitPartition { - const INT32 partition_; - const INT64 offset_; - const INT64 timestamp_; // only v1 - const NULLABLE_STRING metadata_; - - // v0 *and* v2 - OffsetCommitPartition(INT32 partition, INT64 offset, NULLABLE_STRING metadata) - : partition_{partition}, offset_{offset}, timestamp_{-1}, metadata_{metadata} {}; - - // v1 - OffsetCommitPartition(INT32 partition, INT64 offset, INT64 timestamp, NULLABLE_STRING metadata) - : partition_{partition}, offset_{offset}, timestamp_{timestamp}, metadata_{metadata} {}; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(offset_, dst); - if (encoder.apiVersion() == 1) { - written += encoder.encode(timestamp_, dst); - } - written += encoder.encode(metadata_, dst); - return written; - } - - bool operator==(const OffsetCommitPartition& rhs) const { - return partition_ == rhs.partition_ && offset_ == rhs.offset_ && timestamp_ == rhs.timestamp_ && - metadata_ == rhs.metadata_; - }; - - friend std::ostream& operator<<(std::ostream& os, const OffsetCommitPartition& arg) { - return os << "{partition=" << arg.partition_ << ", offset=" << arg.offset_ - << ", timestamp=" << arg.timestamp_ << ", metadata=" << arg.metadata_ << "}"; - } -}; - -struct OffsetCommitTopic { - const STRING topic_; - const NULLABLE_ARRAY partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetCommitTopic& rhs) const { - return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; - }; - - friend std::ostream& operator<<(std::ostream& os, const OffsetCommitTopic& arg) { - return os << "{topic=" << arg.topic_ << ", partitions_=" << arg.partitions_ << "}"; - } -}; - -class OffsetCommitRequest : public Request { -public: - // v0 - OffsetCommitRequest(STRING group_id, NULLABLE_ARRAY topics) - : OffsetCommitRequest(group_id, -1, "", -1, topics){}; - - // v1 - OffsetCommitRequest(STRING group_id, INT32 group_generation_id, STRING member_id, - NULLABLE_ARRAY topics) - : OffsetCommitRequest(group_id, group_generation_id, member_id, -1, topics){}; - - // v2 .. 
v3 - OffsetCommitRequest(STRING group_id, INT32 group_generation_id, STRING member_id, - INT64 retention_time, NULLABLE_ARRAY<OffsetCommitTopic> topics) - : Request{RequestType::OffsetCommit}, group_id_{group_id}, - group_generation_id_{group_generation_id}, member_id_{member_id}, - retention_time_{retention_time}, topics_{topics} {}; - - bool operator==(const OffsetCommitRequest& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && - group_generation_id_ == rhs.group_generation_id_ && member_id_ == rhs.member_id_ && - retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - if (encoder.apiVersion() >= 1) { - written += encoder.encode(group_generation_id_, dst); - written += encoder.encode(member_id_, dst); - } - if (encoder.apiVersion() >= 2) { - written += encoder.encode(retention_time_, dst); - } - written += encoder.encode(topics_, dst); - return written; - } - - std::ostream& printDetails(std::ostream& os) const override { - return os << "{group_id=" << group_id_ << ", group_generation_id=" << group_generation_id_ - << ", member_id=" << member_id_ << ", retention_time=" << retention_time_ - << ", topics=" << topics_ << "}"; - } - -private: - const STRING group_id_; - const INT32 group_generation_id_; // since v1 - const STRING member_id_; // since v1 - const INT64 retention_time_; // since v2 - const NULLABLE_ARRAY<OffsetCommitTopic> topics_; -}; - -// clang-format off -class OffsetCommitPartitionV0Buffer : public CompositeBuffer<OffsetCommitPartition, Int32Buffer, Int64Buffer, NullableStringBuffer> {}; -class OffsetCommitPartitionV0ArrayBuffer : public ArrayBuffer<OffsetCommitPartition, OffsetCommitPartitionV0Buffer> {}; -class OffsetCommitTopicV0Buffer : public CompositeBuffer<OffsetCommitTopic, StringBuffer, OffsetCommitPartitionV0ArrayBuffer> {}; -class OffsetCommitTopicV0ArrayBuffer : public ArrayBuffer<OffsetCommitTopic, OffsetCommitTopicV0Buffer> {}; - -class OffsetCommitPartitionV1Buffer : public CompositeBuffer<OffsetCommitPartition, Int32Buffer, Int64Buffer, Int64Buffer, NullableStringBuffer> {}; -class OffsetCommitPartitionV1ArrayBuffer : public ArrayBuffer<OffsetCommitPartition, OffsetCommitPartitionV1Buffer> {}; -class OffsetCommitTopicV1Buffer : public CompositeBuffer<OffsetCommitTopic, StringBuffer, OffsetCommitPartitionV1ArrayBuffer> {}; -class OffsetCommitTopicV1ArrayBuffer : public ArrayBuffer<OffsetCommitTopic, OffsetCommitTopicV1Buffer> {}; - -class OffsetCommitRequestV0Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, OffsetCommitTopicV0ArrayBuffer> {}; -class OffsetCommitRequestV1Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, Int32Buffer, StringBuffer, OffsetCommitTopicV1ArrayBuffer> {}; - -DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0); -DEFINE_REQUEST_PARSER(OffsetCommitRequest, V1); -// clang-format on - -// === UNKNOWN REQUEST ========================================================= - +/** + * Request that did not have api_key & api_version that could be matched with any of the + * request-specific parsers + */ class UnknownRequest : public Request { public: UnknownRequest(const RequestHeader& request_header) : Request{request_header} {}; @@ -380,10 +286,18 @@ class UnknownRequest : public Request { } }; -// ignores data until the end of request (contained in context_) -class SentinelConsumer : public Parser { +/** + * Sentinel parser that is responsible for consuming message bytes for messages that had an + * unsupported api_key & api_version. It does not attempt to capture any data, just throws it away + * until the end of the message + */ +class SentinelParser : public Parser { public: - SentinelConsumer(RequestContextSharedPtr context) : context_{context} {}; + SentinelParser(RequestContextSharedPtr context) : context_{context} {}; + + /** + * Returns UnknownRequest + */ ParseResponse parse(const char*& buffer, uint64_t& remaining) override; const RequestContextSharedPtr contextForTest() const { return context_; } diff --git a/source/extensions/filters/network/kafka/kafka_types.h
b/source/extensions/filters/network/kafka/kafka_types.h index f5c188e2dff59..0fe1d9591b548 100644 --- a/source/extensions/filters/network/kafka/kafka_types.h +++ b/source/extensions/filters/network/kafka/kafka_types.h @@ -10,20 +10,34 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -typedef int8_t INT8; -typedef int16_t INT16; -typedef int32_t INT32; -typedef int64_t INT64; -typedef uint32_t UINT32; -typedef bool BOOLEAN; - -typedef std::string STRING; -typedef absl::optional NULLABLE_STRING; - -typedef std::vector BYTES; -typedef absl::optional NULLABLE_BYTES; - -template using NULLABLE_ARRAY = absl::optional>; +/** + * Represents a sequence of characters or null. For non-null strings, first the length N is given as + * an INT16. Then N bytes follow which are the UTF-8 encoding of the character sequence. A null + * value is encoded with length of -1 and there are no following bytes. + */ +typedef absl::optional NullableString; + +/** + * Represents a raw sequence of bytes. + * First the length N is given as an INT32. Then N bytes follow. + */ +typedef std::vector Bytes; + +/** + * Represents a raw sequence of bytes or null. For non-null values, first the length N is given as + * an INT32. Then N bytes follow. A null value is encoded with length of -1 and there are no + * following bytes. + */ +typedef absl::optional NullableBytes; + +/** + * Represents a sequence of objects of a given type T. + * Type T can be either a primitive type (e.g. STRING) or a structure. + * First, the length N is given as an INT32. + * Then N instances of type T follow. + * A null array is represented with a length of -1. + */ +template using NullableArray = absl::optional>; } // namespace Kafka } // namespace NetworkFilters diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h index 7d53597745508..a929698844563 100644 --- a/source/extensions/filters/network/kafka/message.h +++ b/source/extensions/filters/network/kafka/message.h @@ -10,7 +10,7 @@ namespace NetworkFilters { namespace Kafka { /** - * Abstract message + * Abstract message (that can be either request or response) */ class Message { public: diff --git a/source/extensions/filters/network/kafka/messages/offset_commit.h b/source/extensions/filters/network/kafka/messages/offset_commit.h new file mode 100644 index 0000000000000..82e9b8bebf3db --- /dev/null +++ b/source/extensions/filters/network/kafka/messages/offset_commit.h @@ -0,0 +1,154 @@ +#pragma once + +#include "extensions/filters/network/kafka/kafka_request.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Generic description : http://kafka.apache.org/protocol.html#The_Messages_OffsetCommit + */ + +/** + * Holds the partition node (leaf) + */ +struct OffsetCommitPartition { + const int32_t partition_; + const int64_t offset_; + const int64_t timestamp_; // only v1 + const NullableString metadata_; + + // v0 *and* v2 + OffsetCommitPartition(int32_t partition, int64_t offset, NullableString metadata) + : partition_{partition}, offset_{offset}, timestamp_{-1}, metadata_{metadata} {}; + + // v1 + OffsetCommitPartition(int32_t partition, int64_t offset, int64_t timestamp, + NullableString metadata) + : partition_{partition}, offset_{offset}, timestamp_{timestamp}, metadata_{metadata} {}; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(partition_, dst); + written += encoder.encode(offset_, 
dst); + if (encoder.apiVersion() == 1) { + written += encoder.encode(timestamp_, dst); + } + written += encoder.encode(metadata_, dst); + return written; + } + + bool operator==(const OffsetCommitPartition& rhs) const { + return partition_ == rhs.partition_ && offset_ == rhs.offset_ && timestamp_ == rhs.timestamp_ && + metadata_ == rhs.metadata_; + }; + + friend std::ostream& operator<<(std::ostream& os, const OffsetCommitPartition& arg) { + return os << "{partition=" << arg.partition_ << ", offset=" << arg.offset_ + << ", timestamp=" << arg.timestamp_ << ", metadata=" << arg.metadata_ << "}"; + } +}; + +/** + * Holds the topic node (contains multiple partitions) + */ +struct OffsetCommitTopic { + const std::string topic_; + const NullableArray partitions_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(topic_, dst); + written += encoder.encode(partitions_, dst); + return written; + } + + bool operator==(const OffsetCommitTopic& rhs) const { + return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; + }; + + friend std::ostream& operator<<(std::ostream& os, const OffsetCommitTopic& arg) { + return os << "{topic=" << arg.topic_ << ", partitions_=" << arg.partitions_ << "}"; + } +}; + +/** + * Holds the request (contains multiple topics) + */ +class OffsetCommitRequest : public Request { +public: + // v0 + OffsetCommitRequest(std::string group_id, NullableArray topics) + : OffsetCommitRequest(group_id, -1, "", -1, topics){}; + + // v1 + OffsetCommitRequest(std::string group_id, int32_t group_generation_id, std::string member_id, + NullableArray topics) + : OffsetCommitRequest(group_id, group_generation_id, member_id, -1, topics){}; + + // v2 .. v3 + OffsetCommitRequest(std::string group_id, int32_t group_generation_id, std::string member_id, + int64_t retention_time, NullableArray topics) + : Request{RequestType::OffsetCommit}, group_id_{group_id}, + group_generation_id_{group_generation_id}, member_id_{member_id}, + retention_time_{retention_time}, topics_{topics} {}; + + bool operator==(const OffsetCommitRequest& rhs) const { + return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && + group_generation_id_ == rhs.group_generation_id_ && member_id_ == rhs.member_id_ && + retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; + }; + +protected: + size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { + size_t written{0}; + written += encoder.encode(group_id_, dst); + if (encoder.apiVersion() >= 1) { + written += encoder.encode(group_generation_id_, dst); + written += encoder.encode(member_id_, dst); + } + if (encoder.apiVersion() >= 2) { + written += encoder.encode(retention_time_, dst); + } + written += encoder.encode(topics_, dst); + return written; + } + + std::ostream& printDetails(std::ostream& os) const override { + return os << "{group_id=" << group_id_ << ", group_generation_id=" << group_generation_id_ + << ", member_id=" << member_id_ << ", retention_time=" << retention_time_ + << ", topics=" << topics_ << "}"; + } + +private: + const std::string group_id_; + const int32_t group_generation_id_; // since v1 + const std::string member_id_; // since v1 + const int64_t retention_time_; // since v2 + const NullableArray topics_; +}; + +// clang-format off +class OffsetCommitPartitionV0Buffer : public CompositeBuffer {}; +class OffsetCommitPartitionV0ArrayBuffer : public ArrayBuffer {}; +class OffsetCommitTopicV0Buffer : public 
CompositeBuffer<OffsetCommitTopic, StringBuffer, OffsetCommitPartitionV0ArrayBuffer> {}; +class OffsetCommitTopicV0ArrayBuffer : public ArrayBuffer<OffsetCommitTopic, OffsetCommitTopicV0Buffer> {}; + +class OffsetCommitPartitionV1Buffer : public CompositeBuffer<OffsetCommitPartition, Int32Buffer, Int64Buffer, Int64Buffer, NullableStringBuffer> {}; +class OffsetCommitPartitionV1ArrayBuffer : public ArrayBuffer<OffsetCommitPartition, OffsetCommitPartitionV1Buffer> {}; +class OffsetCommitTopicV1Buffer : public CompositeBuffer<OffsetCommitTopic, StringBuffer, OffsetCommitPartitionV1ArrayBuffer> {}; +class OffsetCommitTopicV1ArrayBuffer : public ArrayBuffer<OffsetCommitTopic, OffsetCommitTopicV1Buffer> {}; + +class OffsetCommitRequestV0Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, OffsetCommitTopicV0ArrayBuffer> {}; +class OffsetCommitRequestV1Buffer : public CompositeBuffer<OffsetCommitRequest, StringBuffer, Int32Buffer, StringBuffer, OffsetCommitTopicV1ArrayBuffer> {}; + +DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0); +DEFINE_REQUEST_PARSER(OffsetCommitRequest, V1); +// clang-format on + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/parser.h b/source/extensions/filters/network/kafka/parser.h index d9b11ae2b2ec2..8ceade142b6b5 100644 --- a/source/extensions/filters/network/kafka/parser.h +++ b/source/extensions/filters/network/kafka/parser.h @@ -12,25 +12,59 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// === PARSER ================================================================== - class ParseResponse; +/** + * Parser is responsible for consuming data relevant to some part of a message, and then returning + * the decision on how the parsing should continue. + * Implementation note: a better name could be Consumer, but we really don't want to use that word + * considering how prevalent it is in the Kafka world; suggestions welcome + */ class Parser : public Logger::Loggable<Logger::Id::kafka> { public: virtual ~Parser() = default; + /** + * Submit data to be processed by parser; it will consume as much data as is necessary to reach + * a conclusion about what the next parse step should be + * @param buffer data pointer, will be updated by parser + * @param remaining remaining data in buffer, will be updated by parser + * @return parse status - decision what should be done with current parser (keep/replace) + */ virtual ParseResponse parse(const char*& buffer, uint64_t& remaining) PURE; }; typedef std::shared_ptr<Parser> ParserSharedPtr; +/** + * Three-state holder representing one of: + * - parser still needs data (`stillWaiting`) + * - parser is finished, and following parser should be used to process the rest of data + * (`nextParser`) + * - parser is finished, and fully-parsed message is attached (`parsedMessage`) + */ class ParseResponse { public: + /** + * Constructs a response that states that parser still needs data and should not be replaced + */ static ParseResponse stillWaiting() { return {nullptr, nullptr}; } + + /** + * Constructs a response that states that parser is finished and should be replaced by given + * parser + */ static ParseResponse nextParser(ParserSharedPtr next_parser) { return {next_parser, nullptr}; }; + + /** + * Constructs a response that states that parser is finished, the message is ready, and parsing + * can start anew for next message + */ static ParseResponse parsedMessage(MessageSharedPtr message) { return {nullptr, message}; }; + /** + * Whether the response contains a next parser or a fully parsed message + */ bool hasData() const { return (next_parser_ != nullptr) || (message_ != nullptr); } private: diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index 9ec598e30d0a7..3ee2754ed5e80 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -9,8 +9,6 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka
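A condensed sketch of how the decode loop below drives a parse chain over one buffer slice (simplified and illustrative; the actual doParse below also swaps in the follow-up parser and delivers parsed messages before starting over with a fresh RequestStartParser):

  const char* bytes = static_cast<const char*>(slice.mem_);
  uint64_t remaining = slice.len_;
  while (remaining > 0) {
    ParseResponse result = parser->parse(bytes, remaining);
    if (result.hasData()) {
      // either continue with the follow-up parser,
      // or hand the parsed message to the registered callbacks
    }
  }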
-// === DECODER ================================================================= - void RequestDecoder::onData(Buffer::Instance& data) { uint64_t num_slices = data.getRawSlices(nullptr, 0); Buffer::RawSlice slices[num_slices]; @@ -46,14 +44,12 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl } } -// === ENCODER ================================================================= - void RequestEncoder::encode(const Request& message) { EncodingContext encoder{message.apiVersion()}; Buffer::OwnedImpl data_buffer; - INT32 data_len = encoder.encode(message, data_buffer); // encode data computing data length - encoder.encode(data_len, output_); // encode data length into result - output_.add(data_buffer); // copy data into result + int32_t data_len = encoder.encode(message, data_buffer); // encode data computing data length + encoder.encode(data_len, output_); // encode data length into result + output_.add(data_buffer); // copy data into result } } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index ebd797ea1a75a..8078a87e16d65 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -13,8 +13,6 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// === DECODER ================================================================= - /** * Invoked when request is successfully decoded */ @@ -52,8 +50,9 @@ class RequestDecoder : public MessageDecoder, public Logger::Loggable { public: RequestEncoder(Buffer::Instance& output) : output_(output) {} diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index cf1649c2749b3..10f3b4ba42d3a 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -19,31 +19,40 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// ============================================================================= -// === DESERIALIZERS =========================================================== -// ============================================================================= - /** - * The general idea of Buffer is that it can be feed-ed data until it is ready - * When true == ready(), it is safe to call get() - * Further feed()-ing should have no effect on a buffer - * (should return 0 and not move buffer/remaining) + * Deserializer is a stateful entity that constructs a result from bytes provided + * It can be feed()-ed data until it is ready, filling the internal store + * When ready(), it is safe to call get() to transform the internally stored bytes into result + * Further feed()-ing should have no effect on a buffer (should return 0 and not move + * buffer/remaining) */ - -// === ABSTRACT DESERIALIZER =================================================== - template class Deserializer { public: virtual ~Deserializer() = default; + /** + * Submit data to be processed, will consume as much data as it is necessary. 
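+ * A sketch of typical use (names illustrative): successive data chunks are fed until the + * deserializer is ready, e.g.: + *   Int32Buffer deserializer; + *   while (!deserializer.ready() && chunk_remaining > 0) { + *     deserializer.feed(chunk_ptr, chunk_remaining); // advances chunk_ptr, decreases chunk_remaining + *   } + *   const int32_t value = deserializer.get(); // safe once ready() is true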
+ * Invoking this method when deserializer is ready has no effect (consumes 0 bytes) + * @param buffer data pointer, will be updated if data is consumed + * @param remaining remaining data in buffer, will be updated if data is consumed + * @return bytes consumed + */ virtual size_t feed(const char*& buffer, uint64_t& remaining) PURE; + + /** + * Whether the deserializer has consumed enough data to return a result + */ virtual bool ready() const PURE; + + /** + * Returns the entity that is represented by the bytes stored in this deserializer + * Should only be called when the deserializer is ready + */ virtual T get() const PURE; }; -// === INT BUFFERS ============================================================= - /** + * Generic integer deserializer (uses array of sizeof(T) bytes) * The values are encoded in network byte order (big-endian). */ template <typename T> class IntBuffer : public Deserializer<T> { @@ -73,59 +82,74 @@ template <typename T> class IntBuffer : public Deserializer<T> { bool ready_; }; -class Int8Buffer : public IntBuffer<INT8> { +/** + * Deserializer for int8_t + */ +class Int8Buffer : public IntBuffer<int8_t> { public: - INT8 get() const { - INT8 result; + int8_t get() const { + int8_t result; + memcpy(&result, buf_, sizeof(result)); return result; } }; -class Int16Buffer : public IntBuffer<INT16> { +/** + * Deserializer for int16_t + */ +class Int16Buffer : public IntBuffer<int16_t> { public: - INT16 get() const { - INT16 result; + int16_t get() const { + int16_t result; memcpy(&result, buf_, sizeof(result)); return be16toh(result); } }; -class Int32Buffer : public IntBuffer<INT32> { +/** + * Deserializer for int32_t + */ +class Int32Buffer : public IntBuffer<int32_t> { public: - INT32 get() const { - INT32 result; + int32_t get() const { + int32_t result; + memcpy(&result, buf_, sizeof(result)); return be32toh(result); } }; -class UInt32Buffer : public IntBuffer<UINT32> { +/** + * Deserializer for uint32_t + */ +class UInt32Buffer : public IntBuffer<uint32_t> { public: - UINT32 get() const { - UINT32 result; + uint32_t get() const { + uint32_t result; memcpy(&result, buf_, sizeof(result)); return be32toh(result); } }; -class Int64Buffer : public IntBuffer<INT64> { +/** + * Deserializer for int64_t + */ +class Int64Buffer : public IntBuffer<int64_t> { public: - INT64 get() const { - INT64 result; + int64_t get() const { + int64_t result; memcpy(&result, buf_, sizeof(result)); return be64toh(result); } }; -// === BOOL BUFFER ============================================================= - /** - * Represents a boolean value in a byte. + * Deserializer of boolean value + * + * Boolean value is stored in a byte. * Values 0 and 1 are used to represent false and true respectively. * When reading a boolean value, any non-zero value is considered true. */ -class BoolBuffer : public Deserializer<BOOLEAN> { +class BoolBuffer : public Deserializer<bool> { public: BoolBuffer(){}; @@ -133,21 +157,20 @@ class BoolBuffer : public Deserializer<BOOLEAN> { bool ready() const { return buffer_.ready(); } - BOOLEAN get() const { return 0 != buffer_.get(); } + bool get() const { return 0 != buffer_.get(); } private: Int8Buffer buffer_; }; -// === STRING BUFFER =========================================================== - /** - * Represents a sequence of characters. - * First the length N is given as an INT16. + * Deserializer of string value + * + * First the length N is given as an int16_t. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * Length must not be negative.
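 * Wire example: the three-character string "abc" arrives as the five bytes 00 03 61 62 63 * (big-endian int16_t length, then the UTF-8 payload); once those five bytes have been fed, * ready() returns true and get() returns "abc".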
*/ -class StringBuffer : public Deserializer { +class StringBuffer : public Deserializer { public: size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); @@ -161,7 +184,7 @@ class StringBuffer : public Deserializer { if (required_ >= 0) { data_buf_ = std::vector(required_); } else { - throw EnvoyException(fmt::format("invalid STRING length: {}", required_)); + throw EnvoyException(fmt::format("invalid std::string length: {}", required_)); } length_consumed_ = true; } @@ -183,25 +206,26 @@ class StringBuffer : public Deserializer { bool ready() const { return ready_; } - STRING get() const { return std::string(data_buf_.begin(), data_buf_.end()); } + std::string get() const { return std::string(data_buf_.begin(), data_buf_.end()); } private: Int16Buffer length_buf_; bool length_consumed_{false}; - INT16 required_; + int16_t required_; std::vector data_buf_; bool ready_{false}; }; /** - * Represents a sequence of characters or null. - * For non-null strings, first the length N is given as an INT16. + * Deserializer of nullable string value + * + * For non-null strings, first the length N is given as an int16_t. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * A null value is encoded with length of -1 and there are no following bytes. */ -class NullableStringBuffer : public Deserializer { +class NullableStringBuffer : public Deserializer { public: size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); @@ -247,31 +271,34 @@ class NullableStringBuffer : public Deserializer { bool ready() const { return ready_; } - NULLABLE_STRING get() const { + NullableString get() const { return required_ >= 0 ? absl::make_optional(std::string(data_buf_.begin(), data_buf_.end())) : absl::nullopt; } private: - constexpr static INT16 NULL_STRING_LENGTH{-1}; + constexpr static int16_t NULL_STRING_LENGTH{-1}; Int16Buffer length_buf_; bool length_consumed_{false}; - INT16 required_; + int16_t required_; std::vector data_buf_; bool ready_{false}; }; -// === COMPOSITE BUFFER ======================================================== - /** - * Composes several buffers into one. - * The returned value is constructed via { buffer1.get(), buffer2.get() ... } + * Composite deserializer + * Passes data to each of the underlying deserializers (deserializers that are already ready do not + * consume data, so it's safe) Is ready when the last deserializer is ready (which means all + * deserializers before it are ready too) Constructs the result using { buffer1_.get(), + * buffer2_.get() ... } */ template class CompositeBuffer; +// XXX(adam.kotwasinski) I will get rid of this + template class CompositeBuffer : public Deserializer { public: CompositeBuffer(){}; @@ -347,15 +374,20 @@ class CompositeBuffer : public Deserializer { T4 buffer4_; }; -// === ARRAY BUFFER ============================================================ - /** + * Deserializer for array of objects + * First reads the length of the array, then initializes N underlying deserializers of type CT + * After the last of N deserializers is ready, the results of each of them are gathered and put in a + * vector + * @param RT result type returned by deserializer CT + * @param CT underlying deserializer type + * + * Documentation: * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. - * STRING) or a structure. First, the length N is given as an INT32. 
Then N instances of type T + * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T * follow. A null array is represented with a length of -1. */ - -template class ArrayBuffer : public Deserializer> { +template class ArrayBuffer : public Deserializer> { public: size_t feed(const char*& buffer, uint64_t& remaining) { @@ -401,7 +433,7 @@ template class ArrayBuffer : public Deserializer get() const { + NullableArray get() const { if (NULL_ARRAY_LENGTH != required_) { std::vector result{}; result.reserve(children_.size()); @@ -416,20 +448,19 @@ template class ArrayBuffer : public Deserializer children_; bool children_setup_{false}; bool ready_{false}; }; -// === NULL BUFFER ============================================================= - /** - * Consumes no bytes, used as placeholder + * Trivial deserializer that is always ready, and consumes no bytes + * Used as a placeholder in situations when no bytes need to be consumed, but a constant value still has to be returned */ template class NullBuffer : public Deserializer { public: @@ -440,39 +471,56 @@ template class NullBuffer : public Deserializer { RT get() const { return {}; } }; -// ============================================================================= -// === ENCODER HELPER ========================================================== -// ============================================================================= - /** * Encodes provided argument in Kafka format * In case of primitive types, this is done explicitly as per spec * In case of composite types, this is done by calling 'encode' on provided argument + * + * This object also carries extra information that is used while traversing the request + * structure-tree during encoding (currently api_version, as different request versions serialize + * differently) */ - class EncodingContext { public: - EncodingContext(INT16 api_version) : api_version_{api_version} {}; + EncodingContext(int16_t api_version) : api_version_{api_version} {}; + /** + * Encode given reference in a buffer + * @return bytes written + */ template size_t encode(const T& arg, Buffer::Instance& dst); - template size_t encode(const NULLABLE_ARRAY& arg, Buffer::Instance& dst); + /** + * Encode given array in a buffer + * @return bytes written + */ + template size_t encode(const NullableArray& arg, Buffer::Instance& dst); - INT16 apiVersion() const { return api_version_; } + int16_t apiVersion() const { return api_version_; } private: - const INT16 api_version_; + const int16_t api_version_; }; +/** + * For non-primitive types, call `encode` on them, to delegate the serialization to the entity + * itself + */ template inline size_t EncodingContext::encode(const T& arg, Buffer::Instance& dst) { return arg.encode(dst, *this); } -template <> inline size_t EncodingContext::encode(const INT8& arg, Buffer::Instance& dst) { - dst.add(&arg, sizeof(INT8)); - return sizeof(INT8); +/** + * Encode a single byte + */ +template <> inline size_t EncodingContext::encode(const int8_t& arg, Buffer::Instance& dst) { + dst.add(&arg, sizeof(int8_t)); + return sizeof(int8_t); } +/** + * Encode a N-byte integer, converting to network byte-order + */ #define ENCODE_NUMERIC_TYPE(TYPE, CONVERTER) \ template <> inline size_t EncodingContext::encode(const TYPE& arg, Buffer::Instance& dst) { \ TYPE val = CONVERTER(arg); \ @@ -480,63 +528,81 @@ template <> inline size_t EncodingContext::encode(const INT8& arg, Buffer::Insta return sizeof(TYPE); \ } -ENCODE_NUMERIC_TYPE(INT16, htobe16); -ENCODE_NUMERIC_TYPE(INT32, htobe32);
-ENCODE_NUMERIC_TYPE(UINT32, htobe32); -ENCODE_NUMERIC_TYPE(INT64, htobe64); +ENCODE_NUMERIC_TYPE(int16_t, htobe16); +ENCODE_NUMERIC_TYPE(int32_t, htobe32); +ENCODE_NUMERIC_TYPE(uint32_t, htobe32); +ENCODE_NUMERIC_TYPE(int64_t, htobe64); -template <> inline size_t EncodingContext::encode(const BOOLEAN& arg, Buffer::Instance& dst) { - INT8 val = arg; - dst.add(&val, sizeof(INT8)); - return sizeof(INT8); +/** + * Encode boolean as a single byte + */ +template <> inline size_t EncodingContext::encode(const bool& arg, Buffer::Instance& dst) { + int8_t val = arg; + dst.add(&val, sizeof(int8_t)); + return sizeof(int8_t); } -template <> inline size_t EncodingContext::encode(const STRING& arg, Buffer::Instance& dst) { - INT16 string_length = arg.length(); +/** + * Encode string as INT16 length + N bytes + */ +template <> inline size_t EncodingContext::encode(const std::string& arg, Buffer::Instance& dst) { + int16_t string_length = arg.length(); size_t header_length = encode(string_length, dst); dst.add(arg.c_str(), string_length); return header_length + string_length; } +/** + * Encode nullable string as INT16 length + N bytes (length = -1 for null) + */ template <> -inline size_t EncodingContext::encode(const NULLABLE_STRING& arg, Buffer::Instance& dst) { +inline size_t EncodingContext::encode(const NullableString& arg, Buffer::Instance& dst) { if (arg.has_value()) { return encode(*arg, dst); } else { - INT16 len = -1; + int16_t len = -1; return encode(len, dst); } } -template <> inline size_t EncodingContext::encode(const BYTES& arg, Buffer::Instance& dst) { - INT32 data_length = arg.size(); +/** + * Encode byte array as INT32 length + N bytes + */ +template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Instance& dst) { + int32_t data_length = arg.size(); size_t header_length = encode(data_length, dst); dst.add(arg.data(), arg.size()); return header_length + data_length; } -template <> -inline size_t EncodingContext::encode(const NULLABLE_BYTES& arg, Buffer::Instance& dst) { +/** + * Encode nullable byte array as INT32 length + N bytes (length = -1 for null) + */ +template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buffer::Instance& dst) { if (arg.has_value()) { return encode(*arg, dst); } else { - INT32 len = -1; + int32_t len = -1; return encode(len, dst); } } +/** + * Encode nullable object array as INT32 length + N bytes (length = -1 for null) + */ template -size_t EncodingContext::encode(const NULLABLE_ARRAY& arg, Buffer::Instance& dst) { +size_t EncodingContext::encode(const NullableArray& arg, Buffer::Instance& dst) { if (arg.has_value()) { - INT32 len = arg->size(); + int32_t len = arg->size(); size_t header_length = encode(len, dst); size_t written{0}; for (const T& el : *arg) { + // for each of array elements, resolve the correct method again written += encode(el, dst); } return header_length + written; } else { - INT32 len = -1; + int32_t len = -1; return encode(len, dst); } } diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc index 056d116b7715b..5fafce3765973 100644 --- a/test/extensions/filters/network/kafka/kafka_request_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_test.cc @@ -1,4 +1,5 @@ #include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/messages/offset_commit.h" #include "test/mocks/server/mocks.h" @@ -23,7 +24,7 @@ TEST(RequestParserResolver, 
ShouldReturnSentinelIfRequestTypeIsNotRegistered) { // then ASSERT_NE(result, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); } TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) { @@ -40,7 +41,7 @@ TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) // then ASSERT_NE(result, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); } TEST(RequestParserResolver, ShouldInvokeGeneratorFunctionOnMatch) { @@ -79,7 +80,7 @@ TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { // given RequestStartParser testee{RequestParserResolver{{}}}; - INT32 request_len = 1234; + int32_t request_len = 1234; encoder_.encode(request_len, buffer()); const char* bytes = getBytes(); @@ -98,7 +99,7 @@ TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { class MockRequestParserResolver : public RequestParserResolver { public: MockRequestParserResolver() : RequestParserResolver{{}} {}; - MOCK_CONST_METHOD3(createParser, ParserSharedPtr(INT16, INT16, RequestContextSharedPtr)); + MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); }; TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) { @@ -107,15 +108,15 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNext const ParserSharedPtr parser{new OffsetCommitRequestV0Parser{nullptr}}; EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser)); - const INT32 request_len = 1000; + const int32_t request_len = 1000; RequestContextSharedPtr context{new RequestContext()}; context->remaining_request_size_ = request_len; RequestHeaderParser testee{parser_resolver, context}; - const INT16 api_key{1}; - const INT16 api_version{2}; - const INT32 correlation_id{10}; - const NULLABLE_STRING client_id{"aaa"}; + const int16_t api_key{1}; + const int16_t api_version{2}; + const int32_t correlation_id{10}; + const NullableString client_id{"aaa"}; size_t written = 0; written += encoder_.encode(api_key, buffer()); written += encoder_.encode(api_version, buffer()); @@ -141,14 +142,14 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNext ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); } -TEST_F(BufferBasedTest, SentinelConsumerShouldConsumeDataUntilEndOfRequest) { +TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { // given - const INT32 request_len = 1000; + const int32_t request_len = 1000; RequestContextSharedPtr context{new RequestContext()}; context->remaining_request_size_ = request_len; - SentinelConsumer testee{context}; + SentinelParser testee{context}; - const BYTES garbage(request_len * 2); + const Bytes garbage(request_len * 2); encoder_.encode(garbage, buffer()); const char* bytes = getBytes(); diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc index ab1a1798f3358..3bd45bbb52375 100644 --- a/test/extensions/filters/network/kafka/request_codec_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_test.cc @@ -1,3 +1,4 @@ +#include "extensions/filters/network/kafka/messages/offset_commit.h" #include "extensions/filters/network/kafka/request_codec.h" #include "test/mocks/server/mocks.h" @@ -39,11 +40,9 @@ template std::shared_ptr 
RequestDecoderTest::serializeAndDeseria return std::dynamic_pointer_cast(receivedMessage); }; -// === OFFSET COMMIT (8) ======================================================= - TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV0) { // given - NULLABLE_ARRAY topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; + NullableArray topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; OffsetCommitRequest request{"group_id", topics}; request.apiVersion() = 0; request.correlationId() = 10; @@ -60,7 +59,7 @@ TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV0) { TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) { // given // partitions have timestamp in v1 only - NULLABLE_ARRAY topics{ + NullableArray topics{ {{"topic1", {{{0, 10, 100, "m1"}, {2, 20, 101, "m2"}}}}, {"topic2", {{{3, 30, 102, "m3"}}}}}}; OffsetCommitRequest request{"group_id", 40, // group_generation_id @@ -78,12 +77,10 @@ TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) { ASSERT_EQ(*received, request); } -// === UNKNOWN REQUEST ========================================================= - TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given RequestEncoder serializer{buffer_}; - NULLABLE_ARRAY topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; + NullableArray topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; OffsetCommitRequest request{"group_id", topics}; request.apiVersion() = 1; request.correlationId() = 42; diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index e26cdc94946a2..f41a4156e2a10 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -12,8 +12,6 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// === EMPTY (FRESHLY INITIALIZED) BUFFER TESTS ================================ - // freshly created buffers should not be ready #define TEST_EmptyBufferShouldNotBeReady(BufferClass) \ TEST(BufferClass, EmptyBufferShouldNotBeReady) { \ @@ -31,13 +29,13 @@ TEST_EmptyBufferShouldNotBeReady(StringBuffer); TEST_EmptyBufferShouldNotBeReady(NullableStringBuffer); TEST(CompositeBuffer, EmptyBufferShouldNotBeReady) { // given - const CompositeBuffer testee{}; + const CompositeBuffer testee{}; // when, then ASSERT_EQ(testee.ready(), false); } TEST(ArrayBuffer, EmptyBufferShouldNotBeReady) { // given - const ArrayBuffer testee{}; + const ArrayBuffer testee{}; // when, then ASSERT_EQ(testee.ready(), false); } @@ -45,14 +43,12 @@ TEST(ArrayBuffer, EmptyBufferShouldNotBeReady) { // Null buffer is a special case, it's always ready and can provide results via 0-arg ctor TEST(NullBuffer, EmptyBufferShouldBeReady) { // given - const NullBuffer testee{}; + const NullBuffer testee{}; // when, then ASSERT_EQ(testee.ready(), true); ASSERT_EQ(testee.get(), 0); } -// === SERIALIZATION / DESERIALIZATION TESTS =================================== - EncodingContext encoder{-1}; // context is not used when serializing primitive types const char* getRawData(const Buffer::OwnedImpl& buffer) { @@ -146,8 +142,6 @@ template void serializeThenDeserializeAndCheckEqualit serializeThenDeserializeAndCheckEqualityWithChunks(expected); } -// === NUMERIC BUFFERS ========================================================= - // macroed out test for numeric buffers #define TEST_BufferShouldDeserialize(BufferClass, DataClass, Value) \ TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \ @@ -156,22 +150,20 @@ template void 
serializeThenDeserializeAndCheckEqualit serializeThenDeserializeAndCheckEquality(value); \ } -TEST_BufferShouldDeserialize(Int8Buffer, INT8, 42); -TEST_BufferShouldDeserialize(Int16Buffer, INT16, 42); -TEST_BufferShouldDeserialize(Int32Buffer, INT32, 42); -TEST_BufferShouldDeserialize(UInt32Buffer, UINT32, 42); -TEST_BufferShouldDeserialize(Int64Buffer, INT64, 42); -TEST_BufferShouldDeserialize(BoolBuffer, BOOLEAN, true); - -// === (NULLABLE) STRING BUFFER ================================================ +TEST_BufferShouldDeserialize(Int8Buffer, int8_t, 42); +TEST_BufferShouldDeserialize(Int16Buffer, int16_t, 42); +TEST_BufferShouldDeserialize(Int32Buffer, int32_t, 42); +TEST_BufferShouldDeserialize(UInt32Buffer, uint32_t, 42); +TEST_BufferShouldDeserialize(Int64Buffer, int64_t, 42); +TEST_BufferShouldDeserialize(BoolBuffer, bool, true); TEST(StringBuffer, ShouldDeserialize) { - const STRING value = "sometext"; + const std::string value = "sometext"; serializeThenDeserializeAndCheckEquality(value); } TEST(StringBuffer, ShouldDeserializeEmptyString) { - const STRING value = ""; + const std::string value = ""; serializeThenDeserializeAndCheckEquality(value); } @@ -180,7 +172,7 @@ TEST(StringBuffer, ShouldThrowOnInvalidLength) { StringBuffer testee; Buffer::OwnedImpl buffer; - INT16 len = -1; + int16_t len = -1; // STRING accepts only >= 0 encoder.encode(len, buffer); uint64_t remaining = 1024; @@ -193,19 +185,19 @@ TEST(StringBuffer, ShouldThrowOnInvalidLength) { TEST(NullableStringBuffer, ShouldDeserializeString) { // given - const NULLABLE_STRING value{"sometext"}; + const NullableString value{"sometext"}; serializeThenDeserializeAndCheckEquality(value); } TEST(NullableStringBuffer, ShouldDeserializeEmptyString) { // given - const NULLABLE_STRING value{""}; + const NullableString value{""}; serializeThenDeserializeAndCheckEquality(value); } TEST(NullableStringBuffer, ShouldDeserializeAbsentString) { // given - const NULLABLE_STRING value = absl::nullopt; + const NullableString value = absl::nullopt; serializeThenDeserializeAndCheckEquality(value); } @@ -214,7 +206,7 @@ TEST(NullableStringBuffer, ShouldThrowOnInvalidLength) { NullableStringBuffer testee; Buffer::OwnedImpl buffer; - INT16 len = -2; // -1 is OK for NULLABLE_STRING + int16_t len = -2; // -1 is OK for NULLABLE_STRING encoder.encode(len, buffer); uint64_t remaining = 1024; @@ -225,11 +217,9 @@ TEST(NullableStringBuffer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -// === ARRAY BUFFER ============================================================ - TEST(ArrayBuffer, ShouldConsumeCorrectAmountOfData) { - const NULLABLE_ARRAY value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}}; - serializeThenDeserializeAndCheckEquality>(value); + const NullableArray value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}}; + serializeThenDeserializeAndCheckEquality>(value); } TEST(ArrayBuffer, ShouldThrowOnInvalidLength) { @@ -237,7 +227,7 @@ TEST(ArrayBuffer, ShouldThrowOnInvalidLength) { ArrayBuffer testee; Buffer::OwnedImpl buffer; - const INT32 len = -2; // -1 is OK for ARRAY + const int32_t len = -2; // -1 is OK for ARRAY encoder.encode(len, buffer); uint64_t remaining = 1024; @@ -248,12 +238,10 @@ TEST(ArrayBuffer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -// === COMPOSITE BUFFER ======================================================== - struct CompositeBufferResult { - STRING field1_; - NULLABLE_ARRAY field2_; - INT16 field3_; + std::string field1_; 
+ NullableArray field2_; + int16_t field3_; size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { size_t written{0}; @@ -269,7 +257,7 @@ bool operator==(const CompositeBufferResult& lhs, const CompositeBufferResult& r (lhs.field3_ == rhs.field3_); } -typedef CompositeBuffer, +typedef CompositeBuffer, Int16Buffer> TestCompositeBuffer; From 431a64c3ff647be6ea5f0cf1ca2435d80b2b8a28 Mon Sep 17 00:00:00 2001 From: "adam.kotwasinski" Date: Thu, 15 Nov 2018 15:13:43 +0000 Subject: [PATCH 06/29] Apply review fixes - properly access buffer slices - remove CompositeDeserializer and replace it with expanded classes - documentation Signed-off-by: Adam Kotwasinski --- .../filters/network/kafka/kafka_request.cc | 12 +- .../filters/network/kafka/kafka_request.h | 62 +++++-- .../network/kafka/messages/offset_commit.h | 162 ++++++++++++++++-- .../filters/network/kafka/request_codec.cc | 11 +- .../filters/network/kafka/serialization.h | 159 +++++------------ .../network/kafka/kafka_request_test.cc | 6 +- .../network/kafka/serialization_test.cc | 136 ++++++--------- 7 files changed, 315 insertions(+), 233 deletions(-) diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request.cc index 8d1ac4381e60b..7a26f8bf8835b 100644 --- a/source/extensions/filters/network/kafka/kafka_request.cc +++ b/source/extensions/filters/network/kafka/kafka_request.cc @@ -70,9 +70,9 @@ ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api } ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining) { - buffer_.feed(buffer, remaining); - if (buffer_.ready()) { - context_->remaining_request_size_ = buffer_.get(); + request_length_.feed(buffer, remaining); + if (request_length_.ready()) { + context_->remaining_request_size_ = request_length_.get(); return ParseResponse::nextParser( std::make_shared(parser_resolver_, context_)); } else { @@ -81,10 +81,10 @@ ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining } ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remaining) { - context_->remaining_request_size_ -= buffer_.feed(buffer, remaining); + context_->remaining_request_size_ -= deserializer_.feed(buffer, remaining); - if (buffer_.ready()) { - RequestHeader request_header = buffer_.get(); + if (deserializer_.ready()) { + RequestHeader request_header = deserializer_.get(); context_->request_header_ = request_header; ParserSharedPtr next_parser = parser_resolver_.createParser( request_header.api_key_, request_header.api_version_, context_); diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 0f8732717f652..24c0fc0a62801 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -64,8 +64,8 @@ typedef std::function GeneratorFunctio typedef std::unordered_map> GeneratorMap; /** - * Trivial structure specifying which generator function should be used for which api_key & - * api_version + * Trivial structure specifying which generator function should be used + * for which api_key & api_version */ struct ParserSpec { const int16_t api_key_; @@ -128,15 +128,34 @@ class RequestStartParser : public Parser { private: const RequestParserResolver& parser_resolver_; const RequestContextSharedPtr context_; - Int32Buffer buffer_; + Int32Deserializer request_length_; }; /** - * Buffer that gets filled in with request 
header data + * Deserializer that extracts request header * @see http://kafka.apache.org/protocol.html#protocol_messages */ -class RequestHeaderBuffer : public CompositeBuffer {}; +class RequestHeaderDeserializer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += api_key_.feed(buffer, remaining); + consumed += api_version_.feed(buffer, remaining); + consumed += correlation_id_.feed(buffer, remaining); + consumed += client_id_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return client_id_.ready(); } + RequestHeader get() const { + return {api_key_.get(), api_version_.get(), correlation_id_.get(), client_id_.get()}; + } + +protected: + Int16Deserializer api_key_; + Int16Deserializer api_version_; + Int32Deserializer correlation_id_; + NullableStringDeserializer client_id_; +}; /** * Parser responsible for computing request header and updating the context with data resolved @@ -159,33 +178,34 @@ class RequestHeaderParser : public Parser { private: const RequestParserResolver& parser_resolver_; const RequestContextSharedPtr context_; - RequestHeaderBuffer buffer_; + RequestHeaderDeserializer deserializer_; }; /** - * Buffered parser uses a single buffer to construct a response + * Request parser uses a single deserializer to construct a request * This parser is responsible for consuming request-specific data (e.g. topic names) and always * returns a parsed message * @param RT request class - * @param BT buffer type corresponding to request class + * @param BT deserializer type corresponding to request class (should be subclass of + * Deserializer) */ -template class BufferedParser : public Parser { +template class RequestParser : public Parser { public: - BufferedParser(RequestContextSharedPtr context) : context_{context} {}; + RequestParser(RequestContextSharedPtr context) : context_{context} {}; ParseResponse parse(const char*& buffer, uint64_t& remaining) override; protected: RequestContextSharedPtr context_; - BT buffer_; // underlying request-specific buffer + BT deserializer_; // underlying request-specific deserializer }; template -ParseResponse BufferedParser::parse(const char*& buffer, uint64_t& remaining) { - context_->remaining_request_size_ -= buffer_.feed(buffer, remaining); - if (buffer_.ready()) { - // after a successful parse, there should be nothing left +ParseResponse RequestParser::parse(const char*& buffer, uint64_t& remaining) { + context_->remaining_request_size_ -= deserializer_.feed(buffer, remaining); + if (deserializer_.ready()) { + // after a successful parse, there should be nothing left - we have consumed all the bytes ASSERT(0 == context_->remaining_request_size_); - RT request = buffer_.get(); + RT request = deserializer_.get(); request.header() = context_->request_header_; ENVOY_LOG(trace, "parsed request {}: {}", *context_, request); MessageSharedPtr msg = std::make_shared(request); @@ -202,9 +222,9 @@ ParseResponse BufferedParser::parse(const char*& buffer, uint64_t& remai */ #define DEFINE_REQUEST_PARSER(REQUEST_TYPE, VERSION) \ class REQUEST_TYPE##VERSION##Parser \ - : public BufferedParser { \ + : public RequestParser { \ public: \ - REQUEST_TYPE##VERSION##Parser(RequestContextSharedPtr ctx) : BufferedParser{ctx} {}; \ + REQUEST_TYPE##VERSION##Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; \ }; /** @@ -235,10 +255,12 @@ class Request : public Message { */ size_t encode(Buffer::Instance& dst, EncodingContext& context) const { size_t written{0};
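+ // Worked example of the layout produced below (values assumed purely for illustration):
+ // api_key=8, api_version=1, correlation_id=42 and a null client_id would serialize as
+ // 00 08 | 00 01 | 00 00 00 2a | ff ff, before the request-specific fields follow.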
+ // encode request header written += context.encode(request_header_.api_key_, dst); written += context.encode(request_header_.api_version_, dst); written += context.encode(request_header_.correlation_id_, dst); written += context.encode(request_header_.client_id_, dst); + // encode request-specific data written += encodeDetails(dst, context); return written; } @@ -247,7 +269,9 @@ class Request : public Message { * Pretty-prints given request into a stream */ std::ostream& print(std::ostream& os) const override final { + // write header os << request_header_ << " "; // not very pretty + // write request-specific data return printDetails(os); } diff --git a/source/extensions/filters/network/kafka/messages/offset_commit.h b/source/extensions/filters/network/kafka/messages/offset_commit.h index 82e9b8bebf3db..d4a5689195898 100644 --- a/source/extensions/filters/network/kafka/messages/offset_commit.h +++ b/source/extensions/filters/network/kafka/messages/offset_commit.h @@ -13,6 +13,7 @@ namespace Kafka { /** * Holds the partition node (leaf) + * Supports all versions (some fields are not used in some versions) */ struct OffsetCommitPartition { const int32_t partition_; @@ -130,23 +131,160 @@ class OffsetCommitRequest : public Request { const NullableArray topics_; }; -// clang-format off -class OffsetCommitPartitionV0Buffer : public CompositeBuffer {}; -class OffsetCommitPartitionV0ArrayBuffer : public ArrayBuffer {}; -class OffsetCommitTopicV0Buffer : public CompositeBuffer {}; -class OffsetCommitTopicV0ArrayBuffer : public ArrayBuffer {}; +/** + * Deserializes bytes into OffsetCommitPartition (api version 0) + */ +class OffsetCommitPartitionV0Buffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += partition_.feed(buffer, remaining); + consumed += offset_.feed(buffer, remaining); + consumed += metadata_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return metadata_.ready(); } + OffsetCommitPartition get() const { return {partition_.get(), offset_.get(), metadata_.get()}; } + +protected: + Int32Deserializer partition_; + Int64Deserializer offset_; + NullableStringDeserializer metadata_; +}; + +/** + * Deserializes bytes into OffsetCommitPartition (api version 1) + */ +class OffsetCommitPartitionV1Buffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += partition_.feed(buffer, remaining); + consumed += offset_.feed(buffer, remaining); + consumed += timestamp_.feed(buffer, remaining); + consumed += metadata_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return metadata_.ready(); } + OffsetCommitPartition get() const { + return {partition_.get(), offset_.get(), timestamp_.get(), metadata_.get()}; + } + +protected: + Int32Deserializer partition_; + Int64Deserializer offset_; + Int64Deserializer timestamp_; + NullableStringDeserializer metadata_; +}; + +/** + * Deserializes array of OffsetCommitPartition-s v0 + */ +class OffsetCommitPartitionV0ArrayBuffer + : public ArrayDeserializer {}; + +/** + * Deserializes array of OffsetCommitPartition-s v1 + */ +class OffsetCommitPartitionV1ArrayBuffer + : public ArrayDeserializer {}; + +/** + * Deserializes bytes into OffsetCommitTopic v0 (which is composed of topic name + array of v0 + * partitions) + */ +class OffsetCommitTopicV0Buffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + 
consumed += topic_.feed(buffer, remaining); + consumed += partitions_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return partitions_.ready(); } + OffsetCommitTopic get() const { return {topic_.get(), partitions_.get()}; } + +protected: + StringDeserializer topic_; + OffsetCommitPartitionV0ArrayBuffer partitions_; +}; + +/** + * Deserializes bytes into OffsetCommitTopic v1 (which is composed of topic name + array of v1 + * partitions) + */ +class OffsetCommitTopicV1Buffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += topic_.feed(buffer, remaining); + consumed += partitions_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return partitions_.ready(); } + OffsetCommitTopic get() const { return {topic_.get(), partitions_.get()}; } -class OffsetCommitPartitionV1Buffer : public CompositeBuffer {}; -class OffsetCommitPartitionV1ArrayBuffer : public ArrayBuffer {}; -class OffsetCommitTopicV1Buffer : public CompositeBuffer {}; -class OffsetCommitTopicV1ArrayBuffer : public ArrayBuffer {}; +protected: + StringDeserializer topic_; + OffsetCommitPartitionV1ArrayBuffer partitions_; +}; -class OffsetCommitRequestV0Buffer : public CompositeBuffer {}; -class OffsetCommitRequestV1Buffer : public CompositeBuffer {}; +/** + * Deserializes array of OffsetCommitTopic-s v0 + */ +class OffsetCommitTopicV0ArrayBuffer + : public ArrayDeserializer {}; + +/** + * Deserializes array of OffsetCommitTopic-s v1 + */ +class OffsetCommitTopicV1ArrayBuffer + : public ArrayDeserializer {}; + +class OffsetCommitRequestV0Buffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += group_id_.feed(buffer, remaining); + consumed += topics_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return topics_.ready(); } + OffsetCommitRequest get() const { return {group_id_.get(), topics_.get()}; } + +protected: + StringDeserializer group_id_; + OffsetCommitTopicV0ArrayBuffer topics_; +}; + +class OffsetCommitRequestV1Buffer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += group_id_.feed(buffer, remaining); + consumed += generation_id_.feed(buffer, remaining); + consumed += member_id_.feed(buffer, remaining); + consumed += topics_.feed(buffer, remaining); + return consumed; + } + bool ready() const { return topics_.ready(); } + OffsetCommitRequest get() const { + return {group_id_.get(), generation_id_.get(), member_id_.get(), topics_.get()}; + } + +protected: + StringDeserializer group_id_; + Int32Deserializer generation_id_; + StringDeserializer member_id_; + OffsetCommitTopicV1ArrayBuffer topics_; +}; + +/** + * Define Parsers that wrap the corresponding buffers + */ DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0); DEFINE_REQUEST_PARSER(OffsetCommitRequest, V1); -// clang-format on } // namespace Kafka } // namespace NetworkFilters diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index 3ee2754ed5e80..ff3a0169118f6 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -1,6 +1,7 @@ #include "extensions/filters/network/kafka/request_codec.h" #include "common/buffer/buffer_impl.h" +#include "common/common/stack_array.h" #include "extensions/filters/network/kafka/kafka_protocol.h" 
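(A usage sketch of the codec pair touched by the hunks below, for review context; it relies only on APIs visible in this series - `request` stands for any fully populated Request, e.g. the OffsetCommitRequest used in the tests, and the decoder construction is omitted as its arguments are not part of this excerpt:

Buffer::OwnedImpl wire;
RequestEncoder encoder{wire}; // same construction as in request_codec_test.cc
encoder.encode(request);      // frames the message as INT32 length + payload bytes
decoder.onData(wire);         // a RequestDecoder consumes the framed bytes slice by slice
)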
@@ -11,8 +12,8 @@ namespace Kafka { void RequestDecoder::onData(Buffer::Instance& data) { uint64_t num_slices = data.getRawSlices(nullptr, 0); - Buffer::RawSlice slices[num_slices]; - data.getRawSlices(slices, num_slices); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + data.getRawSlices(slices.begin(), num_slices); for (const Buffer::RawSlice& slice : slices) { doParse(current_parser_, slice); } @@ -45,8 +46,14 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl } void RequestEncoder::encode(const Request& message) { + // XXX (adam.kotwasinski) theoretically this context could be generated inside Request::encode (as + // the request knows the api_version), but the serialization design is still to be discussed + // (explicit classes vs vectors of pointers vs templates) EncodingContext encoder{message.apiVersion()}; Buffer::OwnedImpl data_buffer; + // TODO (adam.kotwasinski) precompute the size instead of using a temporary; + // also, when we have a 'computeSize' method, we can push encoding the request's size into + // Request::encode int32_t data_len = encoder.encode(message, data_buffer); // encode data computing data length encoder.encode(data_len, output_); // encode data length into result output_.add(data_buffer); // copy data into result diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index 10f3b4ba42d3a..bf8c5591c3a52 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -20,11 +20,12 @@ namespace NetworkFilters { namespace Kafka { /** - * Deserializer is a stateful entity that constructs a result from bytes provided + * Deserializer is a stateful entity that constructs a result of type T from bytes provided * It can be feed()-ed data until it is ready, filling the internal store * When ready(), it is safe to call get() to transform the internally stored bytes into result * Further feed()-ing should have no effect on a buffer (should return 0 and not move * buffer/remaining) + * @param T type of deserialized data */ template class Deserializer { public: @@ -53,11 +54,11 @@ template class Deserializer { /** * Generic integer deserializer (uses array of sizeof(T) bytes) - * The values are encoded in network byte order (big-endian).
+ * After all bytes are filled in, the value is converted from network byte-order and returned */ -template class IntBuffer : public Deserializer { +template class IntDeserializer : public Deserializer { public: - IntBuffer() : written_{0}, ready_(false){}; + IntDeserializer() : written_{0}, ready_(false){}; size_t feed(const char*& buffer, uint64_t& remaining) { const size_t available = std::min(sizeof(buf_) - written_, remaining); @@ -83,9 +84,9 @@ template class IntBuffer : public Deserializer { }; /** - * Deserializer for int8_t + * Integer deserializer for int8_t */ -class Int8Buffer : public IntBuffer { +class Int8Deserializer : public IntDeserializer { public: int8_t get() const { int8_t result; @@ -95,9 +96,9 @@ class Int8Buffer : public IntBuffer { }; /** - * Deserializer for int16_t + * Integer deserializer for int16_t */ -class Int16Buffer : public IntBuffer { +class Int16Deserializer : public IntDeserializer { public: int16_t get() const { int16_t result; @@ -107,9 +108,9 @@ class Int16Buffer : public IntBuffer { }; /** - * Deserializer for int32_t + * Integer deserializer for int32_t */ -class Int32Buffer : public IntBuffer { +class Int32Deserializer : public IntDeserializer { public: int32_t get() const { int32_t result; @@ -119,9 +120,9 @@ class Int32Buffer : public IntBuffer { }; /** - * Deserializer for uint32_t + * Integer deserializer for uint32_t */ -class UInt32Buffer : public IntBuffer { +class UInt32Deserializer : public IntDeserializer { public: uint32_t get() const { uint32_t result; @@ -131,9 +132,9 @@ class UInt32Buffer : public IntBuffer { }; /** - * Deserializer for uint64_t + * Integer deserializer for int64_t */ -class Int64Buffer : public IntBuffer { +class Int64Deserializer : public IntDeserializer { public: int64_t get() const { int64_t result; @@ -143,7 +144,10 @@ class Int64Buffer : public IntBuffer { }; /** - * Deserializer of boolean value + * Deserializer for boolean values + * Uses a single int8 deserializer, and just checks != 0 + * impl note: could have been a subclass of IntDeserializer with a different get function, + * but that makes it harder to understand * * Boolean value is stored in a byte. * Values 0 and 1 are used to represent false and true respectively. @@ -160,17 +164,19 @@ class BoolBuffer : public Deserializer { bool get() const { return 0 != buffer_.get(); } private: - Int8Buffer buffer_; + Int8Deserializer buffer_; }; /** * Deserializer of string value + * First reads length (INT16) and then allocates the buffer of given length * + * From documentation: * First the length N is given as an int16_t. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * Length must not be negative.
*/ -class StringBuffer : public Deserializer { +class StringDeserializer : public Deserializer { public: size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); @@ -209,7 +215,7 @@ class StringBuffer : public Deserializer { std::string get() const { return std::string(data_buf_.begin(), data_buf_.end()); } private: - Int16Buffer length_buf_; + Int16Deserializer length_buf_; bool length_consumed_{false}; int16_t required_; @@ -220,12 +226,16 @@ class StringBuffer : public Deserializer { /** * Deserializer of nullable string value + * First reads length (INT16) and then allocates the buffer of given length + * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning + * null value) * + * From documentation: * For non-null strings, first the length N is given as an int16_t. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * A null value is encoded with length of -1 and there are no following bytes. */ -class NullableStringBuffer : public Deserializer { +class NullableStringDeserializer : public Deserializer { public: size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); @@ -279,7 +289,7 @@ class NullableStringBuffer : public Deserializer { private: constexpr static int16_t NULL_STRING_LENGTH{-1}; - Int16Buffer length_buf_; + Int16Deserializer length_buf_; bool length_consumed_{false}; int16_t required_; @@ -289,105 +299,21 @@ class NullableStringBuffer : public Deserializer { }; /** - * Composite deserializer - * Passes data to each of the underlying deserializers (deserializers that are already ready do not - * consume data, so it's safe) Is ready when the last deserializer is ready (which means all - * deserializers before it are ready too) Constructs the result using { buffer1_.get(), - * buffer2_.get() ... 
} - */ -template class CompositeBuffer; - -// XXX(adam.kotwasinski) I will get rid of this - -template class CompositeBuffer : public Deserializer { -public: - CompositeBuffer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += buffer1_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return buffer1_.ready(); } - RT get() const { return {buffer1_.get()}; } - -protected: - T1 buffer1_; -}; - -template -class CompositeBuffer : public Deserializer { -public: - CompositeBuffer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += buffer1_.feed(buffer, remaining); - consumed += buffer2_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return buffer2_.ready(); } - RT get() const { return {buffer1_.get(), buffer2_.get()}; } - -protected: - T1 buffer1_; - T2 buffer2_; -}; - -template -class CompositeBuffer : public Deserializer { -public: - CompositeBuffer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += buffer1_.feed(buffer, remaining); - consumed += buffer2_.feed(buffer, remaining); - consumed += buffer3_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return buffer3_.ready(); } - RT get() const { return {buffer1_.get(), buffer2_.get(), buffer3_.get()}; } - -protected: - T1 buffer1_; - T2 buffer2_; - T3 buffer3_; -}; - -template -class CompositeBuffer : public Deserializer { -public: - CompositeBuffer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += buffer1_.feed(buffer, remaining); - consumed += buffer2_.feed(buffer, remaining); - consumed += buffer3_.feed(buffer, remaining); - consumed += buffer4_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return buffer4_.ready(); } - RT get() const { return {buffer1_.get(), buffer2_.get(), buffer3_.get(), buffer4_.get()}; } - -protected: - T1 buffer1_; - T2 buffer2_; - T3 buffer3_; - T4 buffer4_; -}; - -/** - * Deserializer for array of objects + * Deserializer for array of objects of the same type + * * First reads the length of the array, then initializes N underlying deserializers of type CT * After the last of N deserializers is ready, the results of each of them are gathered and put in a * vector * @param RT result type returned by deserializer CT * @param CT underlying deserializer type * - * Documentation: + * From documentation: * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T * follow. A null array is represented with a length of -1. 
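 * For illustration, under the rules above (a derived example, not spec text): the INT32 array
 * {1, 2} is serialized as 00 00 00 02 (count) followed by 00 00 00 01 and 00 00 00 02,
 * while a null array is just ff ff ff ff.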
*/ -template class ArrayBuffer : public Deserializer> { +template +class ArrayDeserializer : public Deserializer> { public: size_t feed(const char*& buffer, uint64_t& remaining) { @@ -450,7 +376,7 @@ template class ArrayBuffer : public Deserializer children_; @@ -462,7 +388,7 @@ template class ArrayBuffer : public Deserializer class NullBuffer : public Deserializer { +template class NullDeserializer : public Deserializer { public: size_t feed(const char*&, uint64_t&) { return 0; } @@ -511,6 +437,7 @@ template inline size_t EncodingContext::encode(const T& arg, Buffer } /** + * Template overload for int8_t * Encode a single byte */ template <> inline size_t EncodingContext::encode(const int8_t& arg, Buffer::Instance& dst) { @@ -519,6 +446,7 @@ template <> inline size_t EncodingContext::encode(const int8_t& arg, Buffer::Ins } /** + * Template overload for int16_t, int32_t, uint32_t, int64_t * Encode a N-byte integer, converting to network byte-order */ #define ENCODE_NUMERIC_TYPE(TYPE, CONVERTER) \ @@ -534,6 +462,7 @@ ENCODE_NUMERIC_TYPE(uint32_t, htobe32); ENCODE_NUMERIC_TYPE(int64_t, htobe64); /** + * Template overload for bool * Encode boolean as a single byte */ template <> inline size_t EncodingContext::encode(const bool& arg, Buffer::Insta @@ -543,6 +472,7 @@ template <> inline size_t EncodingContext::encode(const bool& arg, Buffer::Insta } /** + * Template overload for std::string * Encode string as INT16 length + N bytes */ template <> inline size_t EncodingContext::encode(const std::string& arg, Buffer @@ -553,6 +483,7 @@ template <> inline size_t EncodingContext::encode(const std::string& arg, Buffer } /** + * Template overload for NullableString * Encode nullable string as INT16 length + N bytes (length = -1 for null) */ template <> @@ -566,6 +497,7 @@ inline size_t EncodingContext::encode(const NullableString& arg, Buffer::Instanc } /** + * Template overload for Bytes * Encode byte array as INT32 length + N bytes */ template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Inst @@ -576,6 +508,7 @@ template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Inst } /** + * Template overload for NullableBytes * Encode nullable byte array as INT32 length + N bytes (length = -1 for null) */ template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buff @@ -588,7 +521,8 @@ template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buff } /** - * Encode nullable object array as INT32 length + N bytes (length = -1 for null) + * Encode nullable object array of T as INT32 length + N elements (length = -1 for null) + * Each element of type T then serializes itself on its own */ template size_t EncodingContext::encode(const NullableArray& arg, Buffer::Instance& dst) { @@ -598,6 +532,7 @@ size_t EncodingContext::encode(const NullableArray& arg, Buffer::Instance& ds size_t written{0}; for (const T& el : *arg) { // for each of array elements, resolve the correct method again + // elements could be primitives or complex types, so calling `el.encode()` won't work written += encode(el, dst); } return header_length + written; diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc index 5fafce3765973..c073d2b0e6dbc 100644 --- a/test/extensions/filters/network/kafka/kafka_request_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_test.cc @@ -1,3 +1,5 @@ +#include
"common/common/stack_array.h" + #include "extensions/filters/network/kafka/kafka_request.h" #include "extensions/filters/network/kafka/messages/offset_commit.h" @@ -66,8 +68,8 @@ class BufferBasedTest : public testing::Test { const char* getBytes() { uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); - Buffer::RawSlice slices[num_slices]; - buffer_.getRawSlices(slices, num_slices); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + buffer_.getRawSlices(slices.begin(), num_slices); return reinterpret_cast((slices[0]).mem_); } diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index f41a4156e2a10..d3e06cbcfe606 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -1,3 +1,5 @@ +#include "common/common/stack_array.h" + #include "extensions/filters/network/kafka/serialization.h" #include "test/mocks/server/mocks.h" @@ -12,38 +14,37 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// freshly created buffers should not be ready -#define TEST_EmptyBufferShouldNotBeReady(BufferClass) \ - TEST(BufferClass, EmptyBufferShouldNotBeReady) { \ - const BufferClass testee{}; \ +/** + * Tests in this class are supposed to check whether serialization operations + * on Kafka-primitive types are behaving correctly + */ + +// freshly created deserializers should not be ready +#define TEST_EmptyDeserializerShouldNotBeReady(DeserializerClass) \ + TEST(DeserializerClass, EmptyBufferShouldNotBeReady) { \ + const DeserializerClass testee{}; \ ASSERT_EQ(testee.ready(), false); \ } -TEST_EmptyBufferShouldNotBeReady(Int8Buffer); -TEST_EmptyBufferShouldNotBeReady(Int16Buffer); -TEST_EmptyBufferShouldNotBeReady(Int32Buffer); -TEST_EmptyBufferShouldNotBeReady(UInt32Buffer); -TEST_EmptyBufferShouldNotBeReady(Int64Buffer); -TEST_EmptyBufferShouldNotBeReady(BoolBuffer); -TEST_EmptyBufferShouldNotBeReady(StringBuffer); -TEST_EmptyBufferShouldNotBeReady(NullableStringBuffer); -TEST(CompositeBuffer, EmptyBufferShouldNotBeReady) { +TEST_EmptyDeserializerShouldNotBeReady(Int8Deserializer); +TEST_EmptyDeserializerShouldNotBeReady(Int16Deserializer); +TEST_EmptyDeserializerShouldNotBeReady(Int32Deserializer); +TEST_EmptyDeserializerShouldNotBeReady(UInt32Deserializer); +TEST_EmptyDeserializerShouldNotBeReady(Int64Deserializer); +TEST_EmptyDeserializerShouldNotBeReady(BoolBuffer); +TEST_EmptyDeserializerShouldNotBeReady(StringDeserializer); +TEST_EmptyDeserializerShouldNotBeReady(NullableStringDeserializer); +TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) { // given - const CompositeBuffer testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} -TEST(ArrayBuffer, EmptyBufferShouldNotBeReady) { - // given - const ArrayBuffer testee{}; + const ArrayDeserializer testee{}; // when, then ASSERT_EQ(testee.ready(), false); } -// Null buffer is a special case, it's always ready and can provide results via 0-arg ctor -TEST(NullBuffer, EmptyBufferShouldBeReady) { +// Null deserializer is a special case, it's always ready and can provide results via 0-arg ctor +TEST(NullDeserializer, EmptyBufferShouldBeReady) { // given - const NullBuffer testee{}; + const NullDeserializer testee{}; // when, then ASSERT_EQ(testee.ready(), true); ASSERT_EQ(testee.get(), 0); @@ -51,16 +52,17 @@ TEST(NullBuffer, EmptyBufferShouldBeReady) { EncodingContext encoder{-1}; // context is not used when serializing primitive types +// helper function 
const char* getRawData(const Buffer::OwnedImpl& buffer) { uint64_t num_slices = buffer.getRawSlices(nullptr, 0); - Buffer::RawSlice slices[num_slices]; - buffer.getRawSlices(slices, num_slices); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + buffer.getRawSlices(slices.begin(), num_slices); return reinterpret_cast((slices[0]).mem_); } // exactly what it says on the tin: // 1. serialize expected using Encoder -// 2. deserialize byte array using testee buffer +// 2. deserialize byte array using testee deserializer // 3. verify result = expected // 4. verify that data pointer moved correct amount // 5. feed testee more data @@ -101,7 +103,7 @@ void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { // does the same thing as the above test, // but instead of providing whole data at once, it provides it in N one-byte chunks -// this verifies if buffer keeps state properly +// this verifies if deserializer keeps state properly (no overwrites etc.) template void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { // given @@ -137,39 +139,40 @@ void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { ASSERT_EQ(remaining, 1024); } +// wrapper to run both tests template void serializeThenDeserializeAndCheckEquality(AT expected) { serializeThenDeserializeAndCheckEqualityInOneGo(expected); serializeThenDeserializeAndCheckEqualityWithChunks(expected); } // macroed out test for numeric buffers -#define TEST_BufferShouldDeserialize(BufferClass, DataClass, Value) \ +#define TEST_DeserializerShouldDeserialize(BufferClass, DataClass, Value) \ TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \ /* given */ \ const DataClass value = Value; \ serializeThenDeserializeAndCheckEquality(value); \ } -TEST_BufferShouldDeserialize(Int8Buffer, int8_t, 42); -TEST_BufferShouldDeserialize(Int16Buffer, int16_t, 42); -TEST_BufferShouldDeserialize(Int32Buffer, int32_t, 42); -TEST_BufferShouldDeserialize(UInt32Buffer, uint32_t, 42); -TEST_BufferShouldDeserialize(Int64Buffer, int64_t, 42); -TEST_BufferShouldDeserialize(BoolBuffer, bool, true); +TEST_DeserializerShouldDeserialize(Int8Deserializer, int8_t, 42); +TEST_DeserializerShouldDeserialize(Int16Deserializer, int16_t, 42); +TEST_DeserializerShouldDeserialize(Int32Deserializer, int32_t, 42); +TEST_DeserializerShouldDeserialize(UInt32Deserializer, uint32_t, 42); +TEST_DeserializerShouldDeserialize(Int64Deserializer, int64_t, 42); +TEST_DeserializerShouldDeserialize(BoolBuffer, bool, true); -TEST(StringBuffer, ShouldDeserialize) { +TEST(StringDeserializer, ShouldDeserialize) { const std::string value = "sometext"; - serializeThenDeserializeAndCheckEquality(value); + serializeThenDeserializeAndCheckEquality(value); } -TEST(StringBuffer, ShouldDeserializeEmptyString) { +TEST(StringDeserializer, ShouldDeserializeEmptyString) { const std::string value = ""; - serializeThenDeserializeAndCheckEquality(value); + serializeThenDeserializeAndCheckEquality(value); } -TEST(StringBuffer, ShouldThrowOnInvalidLength) { +TEST(StringDeserializer, ShouldThrowOnInvalidLength) { // given - StringBuffer testee; + StringDeserializer testee; Buffer::OwnedImpl buffer; int16_t len = -1; // STRING accepts only >= 0 @@ -183,27 +186,27 @@ TEST(StringBuffer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -TEST(NullableStringBuffer, ShouldDeserializeString) { +TEST(NullableStringDeserializer, ShouldDeserializeString) { // given const NullableString value{"sometext"}; -
serializeThenDeserializeAndCheckEquality(value); + serializeThenDeserializeAndCheckEquality(value); } -TEST(NullableStringBuffer, ShouldDeserializeEmptyString) { +TEST(NullableStringDeserializer, ShouldDeserializeEmptyString) { // given const NullableString value{""}; - serializeThenDeserializeAndCheckEquality(value); + serializeThenDeserializeAndCheckEquality(value); } -TEST(NullableStringBuffer, ShouldDeserializeAbsentString) { +TEST(NullableStringDeserializer, ShouldDeserializeAbsentString) { // given const NullableString value = absl::nullopt; - serializeThenDeserializeAndCheckEquality(value); + serializeThenDeserializeAndCheckEquality(value); } -TEST(NullableStringBuffer, ShouldThrowOnInvalidLength) { +TEST(NullableStringDeserializer, ShouldThrowOnInvalidLength) { // given - NullableStringBuffer testee; + NullableStringDeserializer testee; Buffer::OwnedImpl buffer; int16_t len = -2; // -1 is OK for NULLABLE_STRING @@ -217,14 +220,15 @@ TEST(NullableStringBuffer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -TEST(ArrayBuffer, ShouldConsumeCorrectAmountOfData) { +TEST(ArrayDeserializer, ShouldConsumeCorrectAmountOfData) { const NullableArray value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}}; - serializeThenDeserializeAndCheckEquality>(value); + serializeThenDeserializeAndCheckEquality>( + value); } -TEST(ArrayBuffer, ShouldThrowOnInvalidLength) { +TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) { // given - ArrayBuffer testee; + ArrayDeserializer testee; Buffer::OwnedImpl buffer; const int32_t len = -2; // -1 is OK for ARRAY @@ -238,34 +242,6 @@ TEST(ArrayBuffer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -struct CompositeBufferResult { - std::string field1_; - NullableArray field2_; - int16_t field3_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - return written; - } -}; - -bool operator==(const CompositeBufferResult& lhs, const CompositeBufferResult& rhs) { - return (lhs.field1_ == rhs.field1_) && (lhs.field2_ == rhs.field2_) && - (lhs.field3_ == rhs.field3_); -} - -typedef CompositeBuffer, - Int16Buffer> - TestCompositeBuffer; - -TEST(CompositeBuffer, ShouldDeserialize) { - const CompositeBufferResult expected{"zzzzz", {{10, 20, 30, 40, 50}}, 1234}; - serializeThenDeserializeAndCheckEquality(expected); -} - } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions From ac5f851c83f6db88735af2b7ce236c99f8d39781 Mon Sep 17 00:00:00 2001 From: "adam.kotwasinski" Date: Mon, 19 Nov 2018 13:34:11 +0000 Subject: [PATCH 07/29] Introduce composite deserializers for 2, 3, 4 delegates Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/BUILD | 5 +- .../filters/network/kafka/kafka_request.h | 32 +--- .../network/kafka/messages/offset_commit.h | 168 ++++-------------- .../network/kafka/serialization_composite.h | 133 ++++++++++++++ .../network/kafka/serialization_test.cc | 119 +++++++++++++ 5 files changed, 294 insertions(+), 163 deletions(-) create mode 100644 source/extensions/filters/network/kafka/serialization_composite.h diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index 8e2c5420f6cc7..438c8a9c3f7bd 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ 
-61,7 +61,10 @@ envoy_cc_library( envoy_cc_library( name = "serialization_lib", - hdrs = ["serialization.h"], + hdrs = [ + "serialization.h", + "serialization_composite.h", + ], deps = [ ":kafka_protocol_lib", "//include/envoy/buffer:buffer_interface", diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 24c0fc0a62801..283b0fb1aa335 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -10,6 +10,7 @@ #include "extensions/filters/network/kafka/kafka_protocol.h" #include "extensions/filters/network/kafka/parser.h" #include "extensions/filters/network/kafka/serialization.h" +#include "extensions/filters/network/kafka/serialization_composite.h" namespace Envoy { namespace Extensions { @@ -131,31 +132,14 @@ class RequestStartParser : public Parser { Int32Deserializer request_length_; }; +// clang-format off /** * Deserializer that extracts request header * @see http://kafka.apache.org/protocol.html#protocol_messages */ -class RequestHeaderDeserializer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += api_key_.feed(buffer, remaining); - consumed += api_version_.feed(buffer, remaining); - consumed += correlation_id_.feed(buffer, remaining); - consumed += client_id_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return client_id_.ready(); } - RequestHeader get() const { - return {api_key_.get(), api_version_.get(), correlation_id_.get(), client_id_.get()}; - } - -protected: - Int16Deserializer api_key_; - Int16Deserializer api_version_; - Int32Deserializer correlation_id_; - NullableStringDeserializer client_id_; -}; +class RequestHeaderDeserializer + : public CompositeDeserializerWith4Delegates {}; +// clang-format on /** * Parser responsible for computing request header and updating the context with data resolved @@ -216,13 +200,13 @@ ParseResponse RequestParser::parse(const char*& buffer, uint64_t& remain } /** - * Macro defining RequestParser that uses the underlying Buffer + * Macro defining RequestParser that uses the underlying Deserializer * Aware of versioning - * Names of Buffers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names + * Names of Deserializers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names */ #define DEFINE_REQUEST_PARSER(REQUEST_TYPE, VERSION) \ class REQUEST_TYPE##VERSION##Parser \ - : public RequestParser { \ + : public RequestParser { \ public: \ REQUEST_TYPE##VERSION##Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; \ }; diff --git a/source/extensions/filters/network/kafka/messages/offset_commit.h b/source/extensions/filters/network/kafka/messages/offset_commit.h index d4a5689195898..9016823bef06d 100644 --- a/source/extensions/filters/network/kafka/messages/offset_commit.h +++ b/source/extensions/filters/network/kafka/messages/offset_commit.h @@ -131,156 +131,48 @@ class OffsetCommitRequest : public Request { const NullableArray topics_; }; -/** - * Deserializes bytes into OffsetCommitPartition (api version 0) - */ -class OffsetCommitPartitionV0Buffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += partition_.feed(buffer, remaining); - consumed += offset_.feed(buffer, remaining); - consumed += metadata_.feed(buffer, remaining); - return consumed; - } - bool ready() const { 
return metadata_.ready(); } - OffsetCommitPartition get() const { return {partition_.get(), offset_.get(), metadata_.get()}; } +// clang-format off -protected: - Int32Deserializer partition_; - Int64Deserializer offset_; - NullableStringDeserializer metadata_; -}; - -/** - * Deserializes bytes into OffsetCommitPartition (api version 1) - */ -class OffsetCommitPartitionV1Buffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += partition_.feed(buffer, remaining); - consumed += offset_.feed(buffer, remaining); - consumed += timestamp_.feed(buffer, remaining); - consumed += metadata_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return metadata_.ready(); } - OffsetCommitPartition get() const { - return {partition_.get(), offset_.get(), timestamp_.get(), metadata_.get()}; - } - -protected: - Int32Deserializer partition_; - Int64Deserializer offset_; - Int64Deserializer timestamp_; - NullableStringDeserializer metadata_; -}; +// api version 0 -/** - * Deserializes array of OffsetCommitPartition-s v0 - */ +// Deserializes bytes into OffsetCommitPartition (api version 0): partition, offset, metadata +class OffsetCommitPartitionV0Buffer + : public CompositeDeserializerWith3Delegates {}; +// Deserializes array of OffsetCommitPartition-s v0 class OffsetCommitPartitionV0ArrayBuffer : public ArrayDeserializer {}; - -/** - * Deserializes array of OffsetCommitPartition-s v1 - */ -class OffsetCommitPartitionV1ArrayBuffer - : public ArrayDeserializer {}; - -/** - * Deserializes bytes into OffsetCommitTopic v0 (which is composed of topic name + array of v0 - * partitions) - */ -class OffsetCommitTopicV0Buffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += topic_.feed(buffer, remaining); - consumed += partitions_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return partitions_.ready(); } - OffsetCommitTopic get() const { return {topic_.get(), partitions_.get()}; } - -protected: - StringDeserializer topic_; - OffsetCommitPartitionV0ArrayBuffer partitions_; -}; - -/** - * Deserializes bytes into OffsetCommitTopic v1 (which is composed of topic name + array of v1 - * partitions) - */ -class OffsetCommitTopicV1Buffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += topic_.feed(buffer, remaining); - consumed += partitions_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return partitions_.ready(); } - OffsetCommitTopic get() const { return {topic_.get(), partitions_.get()}; } - -protected: - StringDeserializer topic_; - OffsetCommitPartitionV1ArrayBuffer partitions_; -}; - -/** - * Deserializes array of OffsetCommitTopic-s v0 - */ +// Deserializes bytes into OffsetCommitTopic (api version 0): topic name, partitions (v0) +class OffsetCommitTopicV0Buffer + : public CompositeDeserializerWith2Delegates {}; +// Deserializes array of OffsetCommitTopic-s v0 class OffsetCommitTopicV0ArrayBuffer : public ArrayDeserializer {}; +// Deserializes bytes into OffsetCommitRequest (api version 0): group_id, topics (v0) +class OffsetCommitRequestV0Deserializer + : public CompositeDeserializerWith2Delegates {}; -/** - * Deserializes array of OffsetCommitTopic-s v1 - */ +// api version 1 + +// Deserializes bytes into OffsetCommitPartition (api version 1): partition, offset, timestamp, metadata +class 
OffsetCommitPartitionV1Buffer + : public CompositeDeserializerWith4Delegates {}; +// Deserializes array of OffsetCommitPartition-s v1 +class OffsetCommitPartitionV1ArrayBuffer + : public ArrayDeserializer {}; +// Deserializes bytes into OffsetCommitTopic (api version 1): topic name, partitions (v1) +class OffsetCommitTopicV1Buffer + : public CompositeDeserializerWith2Delegates {}; +// Deserializes array of OffsetCommitTopic-s v1 class OffsetCommitTopicV1ArrayBuffer : public ArrayDeserializer {}; +// Deserializes bytes into OffsetCommitRequest (api version 1): group_id, generation_id, member_id, topics (v1) +class OffsetCommitRequestV1Deserializer + : public CompositeDeserializerWith4Delegates {}; -class OffsetCommitRequestV0Buffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += group_id_.feed(buffer, remaining); - consumed += topics_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return topics_.ready(); } - OffsetCommitRequest get() const { return {group_id_.get(), topics_.get()}; } - -protected: - StringDeserializer group_id_; - OffsetCommitTopicV0ArrayBuffer topics_; -}; - -class OffsetCommitRequestV1Buffer : public Deserializer { -public: - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += group_id_.feed(buffer, remaining); - consumed += generation_id_.feed(buffer, remaining); - consumed += member_id_.feed(buffer, remaining); - consumed += topics_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return topics_.ready(); } - OffsetCommitRequest get() const { - return {group_id_.get(), generation_id_.get(), member_id_.get(), topics_.get()}; - } - -protected: - StringDeserializer group_id_; - Int32Deserializer generation_id_; - StringDeserializer member_id_; - OffsetCommitTopicV1ArrayBuffer topics_; -}; +// clang-format on /** - * Define Parsers that wrap the corresponding buffers + * Define Parsers that wrap the corresponding deserializers */ DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0); diff --git a/source/extensions/filters/network/kafka/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h new file mode 100644 index 0000000000000..6147dabc2dab0 --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization_composite.h @@ -0,0 +1,133 @@ +#pragma once + +#include +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" +#include "envoy/common/pure.h" + +#include "common/common/byte_order.h" +#include "common/common/fmt.h" + +#include "extensions/filters/network/kafka/kafka_types.h" +#include "extensions/filters/network/kafka/serialization.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * This header contains only composite deserializers + * The basic design is composite deserializer creating delegates T1..Tn + * Result of type RT is constructed by getting results of each of delegates + */ + +/** + * Composite deserializer that uses 2 deserializers + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type RT using { delegate1_.get(), delegate2_.get() ... 
}
+ *
+ * @param RT type of deserialized data
+ * @param T1 1st deserializer (result used as 1st argument of RT's ctor)
+ * @param T2 2nd deserializer (result used as 2nd argument of RT's ctor)
+ */
+template <typename RT, typename T1, typename T2>
+class CompositeDeserializerWith2Delegates : public Deserializer<RT> {
+public:
+  CompositeDeserializerWith2Delegates(){};
+  size_t feed(const char*& buffer, uint64_t& remaining) {
+    size_t consumed = 0;
+    consumed += delegate1_.feed(buffer, remaining);
+    consumed += delegate2_.feed(buffer, remaining);
+    return consumed;
+  }
+  bool ready() const { return delegate2_.ready(); }
+  RT get() const { return {delegate1_.get(), delegate2_.get()}; }
+
+protected:
+  T1 delegate1_;
+  T2 delegate2_;
+};
+
+/**
+ * Composite deserializer that uses 3 deserializers
+ * Passes data to each of the underlying deserializers
+ * (deserializers that are already ready do not consume data, so it's safe).
+ * The composite deserializer is ready when the last deserializer is ready
+ * (which means all deserializers before it are ready too)
+ * Constructs the result of type RT using { delegate1_.get(), delegate2_.get() ... }
+ *
+ * @param RT type of deserialized data
+ * @param T1 1st deserializer (result used as 1st argument of RT's ctor)
+ * @param T2 2nd deserializer (result used as 2nd argument of RT's ctor)
+ * @param T3 3rd deserializer (result used as 3rd argument of RT's ctor)
+ */
+template <typename RT, typename T1, typename T2, typename T3>
+class CompositeDeserializerWith3Delegates : public Deserializer<RT> {
+public:
+  CompositeDeserializerWith3Delegates(){};
+  size_t feed(const char*& buffer, uint64_t& remaining) {
+    size_t consumed = 0;
+    consumed += delegate1_.feed(buffer, remaining);
+    consumed += delegate2_.feed(buffer, remaining);
+    consumed += delegate3_.feed(buffer, remaining);
+    return consumed;
+  }
+  bool ready() const { return delegate3_.ready(); }
+  RT get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; }
+
+protected:
+  T1 delegate1_;
+  T2 delegate2_;
+  T3 delegate3_;
+};
+
+/**
+ * Composite deserializer that uses 4 deserializers
+ * Passes data to each of the underlying deserializers
+ * (deserializers that are already ready do not consume data, so it's safe).
+ * The composite deserializer is ready when the last deserializer is ready
+ * (which means all deserializers before it are ready too)
+ * Constructs the result of type RT using { delegate1_.get(), delegate2_.get() ... }
+ *
+ * @param RT type of deserialized data
+ * @param T1 1st deserializer (result used as 1st argument of RT's ctor)
+ * @param T2 2nd deserializer (result used as 2nd argument of RT's ctor)
+ * @param T3 3rd deserializer (result used as 3rd argument of RT's ctor)
+ * @param T4 4th deserializer (result used as 4th argument of RT's ctor)
+ */
+template <typename RT, typename T1, typename T2, typename T3, typename T4>
+class CompositeDeserializerWith4Delegates : public Deserializer<RT> {
+public:
+  CompositeDeserializerWith4Delegates(){};
+  size_t feed(const char*& buffer, uint64_t& remaining) {
+    size_t consumed = 0;
+    consumed += delegate1_.feed(buffer, remaining);
+    consumed += delegate2_.feed(buffer, remaining);
+    consumed += delegate3_.feed(buffer, remaining);
+    consumed += delegate4_.feed(buffer, remaining);
+    return consumed;
+  }
+  bool ready() const { return delegate4_.ready(); }
+  RT get() const {
+    return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get()};
+  }
+
+protected:
+  T1 delegate1_;
+  T2 delegate2_;
+  T3 delegate3_;
+  T4 delegate4_;
+};
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc
index d3e06cbcfe606..9555b3e639c9b 100644
--- a/test/extensions/filters/network/kafka/serialization_test.cc
+++ b/test/extensions/filters/network/kafka/serialization_test.cc
@@ -1,6 +1,7 @@
 #include "common/common/stack_array.h"
 
 #include "extensions/filters/network/kafka/serialization.h"
+#include "extensions/filters/network/kafka/serialization_composite.h"
 
 #include "test/mocks/server/mocks.h"
 
@@ -34,6 +35,38 @@ TEST_EmptyDeserializerShouldNotBeReady(Int64Deserializer);
 TEST_EmptyDeserializerShouldNotBeReady(BoolBuffer);
 TEST_EmptyDeserializerShouldNotBeReady(StringDeserializer);
 TEST_EmptyDeserializerShouldNotBeReady(NullableStringDeserializer);
+TEST(CompositeDeserializerWith2Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  struct CompositeResult {
+    CompositeResult(int8_t, int16_t){};
+  };
+  const CompositeDeserializerWith2Delegates<CompositeResult, Int8Deserializer, Int16Deserializer>
+      testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+TEST(CompositeDeserializerWith3Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  struct CompositeResult {
+    CompositeResult(int8_t, int16_t, int32_t){};
+  };
+  const CompositeDeserializerWith3Delegates<CompositeResult, Int8Deserializer, Int16Deserializer,
+                                            Int32Deserializer>
+      testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+TEST(CompositeDeserializerWith4Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  struct CompositeResult {
+    CompositeResult(int8_t, int16_t, int32_t, std::string){};
+  };
+  const CompositeDeserializerWith4Delegates<CompositeResult, Int8Deserializer, Int16Deserializer,
+                                            Int32Deserializer, StringDeserializer>
+      testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
 TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) {
   // given
   const ArrayDeserializer<std::string, StringDeserializer> testee{};
@@ -242,6 +275,92 @@ TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) {
   EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
 }
 
+// tests for composite deserializers
+
+struct CompositeResultWith2Fields {
+  std::string field1_;
+  NullableArray<int32_t> field2_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    written += encoder.encode(field2_, dst);
+    return written;
+  }
+
+  bool operator==(const CompositeResultWith2Fields& rhs) const {
+    return (field1_ == rhs.field1_) && (field2_ == rhs.field2_);
+  }
+};
+
+struct CompositeResultWith3Fields {
+  std::string field1_;
+  NullableArray<int32_t> field2_;
+  int16_t field3_;
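+
+  // note: serializeThenDeserializeAndCheckEquality (the helper defined earlier
+  // in this file) drives the round-trip for these structs: it calls encode()
+  // below to render the fields into a buffer, then feeds the raw bytes to the
+  // composite deserializer under test and asserts that get() reproduces the
+  // original value, so encode() must write the fields in the exact order the
+  // delegates consume them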
+ + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(field1_, dst); + written += encoder.encode(field2_, dst); + written += encoder.encode(field3_, dst); + return written; + } + + bool operator==(const CompositeResultWith3Fields& rhs) const { + return (field1_ == rhs.field1_) && (field2_ == rhs.field2_) && (field3_ == rhs.field3_); + } +}; + +struct CompositeResultWith4Fields { + std::string field1_; + NullableArray field2_; + int16_t field3_; + std::string field4_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(field1_, dst); + written += encoder.encode(field2_, dst); + written += encoder.encode(field3_, dst); + written += encoder.encode(field4_, dst); + return written; + } + + bool operator==(const CompositeResultWith4Fields& rhs) const { + return (field1_ == rhs.field1_) && (field2_ == rhs.field2_) && (field3_ == rhs.field3_) && + (field4_ == rhs.field4_); + } +}; + +typedef CompositeDeserializerWith2Delegates> + TestCompositeDeserializer2; + +typedef CompositeDeserializerWith3Delegates, + Int16Deserializer> + TestCompositeDeserializer3; + +typedef CompositeDeserializerWith4Delegates, + Int16Deserializer, StringDeserializer> + TestCompositeDeserializer4; + +TEST(CompositeDeserializerWith2Delegates, ShouldDeserialize) { + const CompositeResultWith2Fields expected{"zzzzz", {{10, 20, 30, 40, 50}}}; + serializeThenDeserializeAndCheckEquality(expected); +} + +TEST(CompositeDeserializerWith3Delegates, ShouldDeserialize) { + const CompositeResultWith3Fields expected{"zzzzz", {{10, 20, 30, 40, 50}}, 1234}; + serializeThenDeserializeAndCheckEquality(expected); +} + +TEST(CompositeDeserializerWith4Delegates, ShouldDeserialize) { + const CompositeResultWith4Fields expected{"zzzzz", {{10, 20, 30, 40, 50}}, 1234, "aaa"}; + serializeThenDeserializeAndCheckEquality(expected); +} + } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions From 8cb67ee512abc9e2116f08d714e6c5a1669dbdca Mon Sep 17 00:00:00 2001 From: "adam.kotwasinski" Date: Tue, 20 Nov 2018 14:13:12 +0000 Subject: [PATCH 08/29] Review fixes: - some renames - more comments - missing include Signed-off-by: Adam Kotwasinski --- .../extensions/filters/network/kafka/codec.h | 2 + .../filters/network/kafka/debug_helpers.h | 2 + .../filters/network/kafka/kafka_protocol.h | 2 +- .../filters/network/kafka/kafka_request.cc | 22 +++--- .../filters/network/kafka/kafka_request.h | 71 ++++++++++++------- .../filters/network/kafka/message.h | 1 + .../network/kafka/messages/offset_commit.h | 7 +- .../filters/network/kafka/request_codec.cc | 11 +++ .../filters/network/kafka/request_codec.h | 21 ++++++ .../network/kafka/kafka_request_test.cc | 4 +- 10 files changed, 100 insertions(+), 43 deletions(-) diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h index 0e255b509bc21..705c87b0f7d58 100644 --- a/source/extensions/filters/network/kafka/codec.h +++ b/source/extensions/filters/network/kafka/codec.h @@ -8,6 +8,8 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { +// abstract codecs for requests and responses + /** * Kafka message decoder * @tparam MT message type (Kafka request or Kafka response) diff --git a/source/extensions/filters/network/kafka/debug_helpers.h b/source/extensions/filters/network/kafka/debug_helpers.h index 17224c3baa54a..0a2f1c6bdf82a 100644 --- 
a/source/extensions/filters/network/kafka/debug_helpers.h +++ b/source/extensions/filters/network/kafka/debug_helpers.h @@ -12,6 +12,7 @@ namespace Kafka { // functions present in this header are used by request / response objects to print their fields // nicely +// prints out std::vector template std::ostream& operator<<(std::ostream& os, const std::vector& arg) { os << "["; for (auto iter = arg.begin(); iter != arg.end(); iter++) { @@ -24,6 +25,7 @@ template std::ostream& operator<<(std::ostream& os, const std::vect return os; } +// prints out absl::optional template std::ostream& operator<<(std::ostream& os, const absl::optional& arg) { if (arg.has_value()) { os << *arg; diff --git a/source/extensions/filters/network/kafka/kafka_protocol.h b/source/extensions/filters/network/kafka/kafka_protocol.h index ca785f14c1500..ecbe73aec1287 100644 --- a/source/extensions/filters/network/kafka/kafka_protocol.h +++ b/source/extensions/filters/network/kafka/kafka_protocol.h @@ -12,7 +12,7 @@ namespace NetworkFilters { namespace Kafka { /** - * Kafka request type identifier + * Kafka request type identifier (int16_t value present in header of every request) * @see http://kafka.apache.org/protocol.html#protocol_api_keys */ enum RequestType : int16_t { diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request.cc index 7a26f8bf8835b..529822be8a069 100644 --- a/source/extensions/filters/network/kafka/kafka_request.cc +++ b/source/extensions/filters/network/kafka/kafka_request.cc @@ -10,9 +10,9 @@ namespace NetworkFilters { namespace Kafka { // helper function that generates a map from specs looking like { api_key, api_versions... } -GeneratorMap computeGeneratorMap(const GeneratorMap& original, - const std::vector specs) { - GeneratorMap result{original}; +ParserGenerators computeGeneratorMap(const ParserGenerators& original, + const std::vector specs) { + ParserGenerators result{original}; for (auto& spec : specs) { auto& generators = result[spec.api_key_]; for (int16_t api_version : spec.api_versions_) { @@ -23,13 +23,16 @@ GeneratorMap computeGeneratorMap(const GeneratorMap& original, return result; } -RequestParserResolver::RequestParserResolver(const std::vector arg) - : generators_{computeGeneratorMap({}, arg)} {}; +RequestParserResolver::RequestParserResolver(const std::vector specs) + : generators_{computeGeneratorMap({}, specs)} {}; RequestParserResolver::RequestParserResolver(const RequestParserResolver& original, - const std::vector arg) - : generators_{computeGeneratorMap(original.generators_, arg)} {}; + const std::vector specs) + : generators_{computeGeneratorMap(original.generators_, specs)} {}; +// helper macro binding request type & api versions to Deserializers +// the rendered function will create a new instance of (REQUEST)RequestV(Version)Parser +// e.g. OffsetCommitRequestV0Parser #define PARSER_SPEC(REQUEST_NAME, PARSER_VERSION, ...) 
\ ParserSpec { \ RequestType::REQUEST_NAME, {__VA_ARGS__}, [](RequestContextSharedPtr arg) -> ParserSharedPtr { \ @@ -56,7 +59,8 @@ ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api if (generators_.end() == api_versions_ptr) { return std::make_shared(context); } - const std::unordered_map& api_versions = api_versions_ptr->second; + const std::unordered_map& api_versions = + api_versions_ptr->second; // api_version const auto generator_ptr = api_versions.find(api_version); @@ -65,7 +69,7 @@ ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api } // found matching parser generator, create parser - const GeneratorFunction generator = generator_ptr->second; + const ParserGeneratorFunction generator = generator_ptr->second; return generator(context); } diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 283b0fb1aa335..c6003293dcf1d 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -57,21 +57,22 @@ typedef std::shared_ptr RequestContextSharedPtr; /** * Function generating a parser with given context */ -typedef std::function GeneratorFunction; +typedef std::function ParserGeneratorFunction; /** - * Structure responsible for mapping [api_key, api_version] -> GeneratorFunction + * Structure responsible for mapping [api_key, api_version] -> ParserGeneratorFunction */ -typedef std::unordered_map> GeneratorMap; +typedef std::unordered_map> + ParserGenerators; /** - * Trivial structure specifying which generator function should be used + * Trivial structure specifying which parser generator function should be used * for which api_key & api_version */ struct ParserSpec { const int16_t api_key_; const std::vector api_versions_; - const GeneratorFunction generator_; + const ParserGeneratorFunction generator_; }; /** @@ -81,8 +82,17 @@ struct ParserSpec { */ class RequestParserResolver { public: - RequestParserResolver(const std::vector arg); - RequestParserResolver(const RequestParserResolver& original, const std::vector arg); + /** + * Creates a resolver that uses generator functions provided by given specifications + */ + RequestParserResolver(const std::vector specs); + + /** + * Creates a resolver that uses generator functions provided by original resolver and then + * expanded by specifications + */ + RequestParserResolver(const RequestParserResolver& original, const std::vector specs); + virtual ~RequestParserResolver() = default; /** @@ -106,7 +116,7 @@ class RequestParserResolver { static const RequestParserResolver KAFKA_1_0; private: - GeneratorMap generators_; + ParserGenerators generators_; }; /** @@ -166,7 +176,7 @@ class RequestHeaderParser : public Parser { }; /** - * Buffered parser uses a single deserializer to construct a response + * Request parser uses a single deserializer to construct a request object * This parser is responsible for consuming request-specific data (e.g. 
topic names) and always * returns a parsed message * @param RT request class @@ -175,34 +185,41 @@ class RequestHeaderParser : public Parser { */ template class RequestParser : public Parser { public: + /** + * Create a parser with given context + * @param context parse context containing request header + */ RequestParser(RequestContextSharedPtr context) : context_{context} {}; - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + + /** + * Consume enough data to fill in deserializer and receive the parsed request + * Fill in request's header with data stored in context + */ + ParseResponse parse(const char*& buffer, uint64_t& remaining) { + context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); + if (deserializer.ready()) { + // after a successful parse, there should be nothing left - we have consumed all the bytes + ASSERT(0 == context_->remaining_request_size_); + RT request = deserializer.get(); + request.header() = context_->request_header_; + ENVOY_LOG(trace, "parsed request {}: {}", *context_, request); + MessageSharedPtr msg = std::make_shared(request); + return ParseResponse::parsedMessage(msg); + } else { + return ParseResponse::stillWaiting(); + } + } protected: RequestContextSharedPtr context_; BT deserializer; // underlying request-specific deserializer }; -template -ParseResponse RequestParser::parse(const char*& buffer, uint64_t& remaining) { - context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); - if (deserializer.ready()) { - // after a successful parse, there should be nothing left - we have consumed all the bytes - ASSERT(0 == context_->remaining_request_size_); - RT request = deserializer.get(); - request.header() = context_->request_header_; - ENVOY_LOG(trace, "parsed request {}: {}", *context_, request); - MessageSharedPtr msg = std::make_shared(request); - return ParseResponse::parsedMessage(msg); - } else { - return ParseResponse::stillWaiting(); - } -} - /** - * Macro defining RequestParser that uses the underlying Deserializer + * Helper macro defining RequestParser that uses the underlying Deserializer * Aware of versioning * Names of Deserializers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names + * Renders class named (Request)(Version)Parser e.g. 
OffsetCommitRequestV0Parser */ #define DEFINE_REQUEST_PARSER(REQUEST_TYPE, VERSION) \ class REQUEST_TYPE##VERSION##Parser \ diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h index a929698844563..b54c3743bbf07 100644 --- a/source/extensions/filters/network/kafka/message.h +++ b/source/extensions/filters/network/kafka/message.h @@ -1,5 +1,6 @@ #pragma once +#include #include #include "envoy/common/pure.h" diff --git a/source/extensions/filters/network/kafka/messages/offset_commit.h b/source/extensions/filters/network/kafka/messages/offset_commit.h index 9016823bef06d..e1b6288468f29 100644 --- a/source/extensions/filters/network/kafka/messages/offset_commit.h +++ b/source/extensions/filters/network/kafka/messages/offset_commit.h @@ -12,8 +12,7 @@ namespace Kafka { */ /** - * Holds the partition node (leaf) - * Supports all versions (some fields are not used in some versions) + * Holds the partition data: partition, offset, timestamp, metadata */ struct OffsetCommitPartition { const int32_t partition_; @@ -53,7 +52,7 @@ struct OffsetCommitPartition { }; /** - * Holds the topic node (contains multiple partitions) + * Holds the topic data: topic name and partitions in that topic */ struct OffsetCommitTopic { const std::string topic_; @@ -76,7 +75,7 @@ struct OffsetCommitTopic { }; /** - * Holds the request (contains multiple topics) + * Holds the request: group id, generation id, member id, retention time, topics */ class OffsetCommitRequest : public Request { public: diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index ff3a0169118f6..04468f2978395 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -10,6 +10,7 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { +// convert buffer to slices and pass them to `doParse` void RequestDecoder::onData(Buffer::Instance& data) { uint64_t num_slices = data.getRawSlices(nullptr, 0); STACK_ARRAY(slices, Buffer::RawSlice, num_slices); @@ -19,6 +20,16 @@ void RequestDecoder::onData(Buffer::Instance& data) { } } +/** + * Main parse loop: + * - forward data to current parser + * - receive parser response: + * -- if still waiting, do nothing + * -- if next parser, replace parser, and keep feeding, if still have data + * -- if parser message: + * --- notify callbacks + * --- replace parser with new start parser, as we are going to parse another request + */ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& slice) { const char* buffer = reinterpret_cast(slice.mem_); uint64_t remaining = slice.len_; diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index 8078a87e16d65..77e40fccdd856 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -31,14 +31,28 @@ typedef std::shared_ptr RequestCallbackSharedPtr; * * This decoder uses chain of parsers to parse fragments of a request * Each parser along the line returns the fully parsed message or the next parser + * Stores parse state (have `onData` invoked multiple times for messages that are larger than single + * buffer) */ class RequestDecoder : public MessageDecoder, public Logger::Loggable { public: + /** + * Creates a decoder that can decode requests specified by RequestParserResolver, 
notifying + * callbacks on successful decoding + * @param parserResolver supported parser resolver + * @param callbacks callbacks to be invoked (in order) + */ RequestDecoder(const RequestParserResolver parserResolver, const std::vector callbacks) : parser_resolver_{parserResolver}, callbacks_{callbacks}, current_parser_{new RequestStartParser(parser_resolver_)} {}; + /** + * Consumes all data present in a buffer + * If a request can be successfully parsed, then callbacks get notified with parsed request + * Updates decoder state + * impl note: similar to redis codec, which also keeps state + */ void onData(Buffer::Instance& data); private: @@ -55,7 +69,14 @@ class RequestDecoder : public MessageDecoder, public Logger::Loggable { public: + /** + * Wraps buffer with encoder + */ RequestEncoder(Buffer::Instance& output) : output_(output) {} + + /** + * Encodes request into wrapped buffer + */ void encode(const Request& message) override; private: diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc index c073d2b0e6dbc..268c6eda29678 100644 --- a/test/extensions/filters/network/kafka/kafka_request_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_test.cc @@ -31,7 +31,7 @@ TEST(RequestParserResolver, ShouldReturnSentinelIfRequestTypeIsNotRegistered) { TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) { // given - GeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { + ParserGeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { return std::make_shared(arg); }; RequestParserResolver testee{{{0, {0, 1}, generator}}}; @@ -48,7 +48,7 @@ TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) TEST(RequestParserResolver, ShouldInvokeGeneratorFunctionOnMatch) { // given - GeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { + ParserGeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { return std::make_shared(arg); }; RequestParserResolver testee{{{0, {0, 1, 2, 3}, generator}}}; From 7e3f54ca13265226b08b8a7a48ae8065af11558e Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 27 Nov 2018 13:48:30 +0000 Subject: [PATCH 09/29] Apply review fixes: - documentation - remove unused code / imports - add Bytes/NullableBytesDeserializer, remove NullDeserializer - remove print capability from Message/Request Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/BUILD | 1 - .../extensions/filters/network/kafka/codec.h | 22 ++- .../filters/network/kafka/debug_helpers.h | 41 ---- .../filters/network/kafka/kafka_protocol.h | 6 - .../filters/network/kafka/kafka_request.h | 54 ++---- .../filters/network/kafka/kafka_types.h | 17 +- .../filters/network/kafka/message.h | 5 - .../network/kafka/messages/offset_commit.h | 45 ++--- .../filters/network/kafka/request_codec.cc | 1 - .../filters/network/kafka/request_codec.h | 10 +- .../filters/network/kafka/serialization.h | 180 ++++++++++++++---- .../network/kafka/serialization_composite.h | 72 +++---- .../filters/network/well_known_names.h | 2 - .../network/kafka/serialization_test.cc | 74 +++++-- 14 files changed, 303 insertions(+), 227 deletions(-) delete mode 100644 source/extensions/filters/network/kafka/debug_helpers.h diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index 438c8a9c3f7bd..cbd1bc719cf54 100644 --- 
a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -28,7 +28,6 @@ envoy_cc_library( name = "kafka_request_lib", srcs = ["kafka_request.cc"], hdrs = [ - "debug_helpers.h", "kafka_request.h", "messages/offset_commit.h", ], diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h index 705c87b0f7d58..a68d798ac400e 100644 --- a/source/extensions/filters/network/kafka/codec.h +++ b/source/extensions/filters/network/kafka/codec.h @@ -8,26 +8,34 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// abstract codecs for requests and responses - /** * Kafka message decoder - * @tparam MT message type (Kafka request or Kafka response) + * @tparam MessageType message type (Kafka request or Kafka response) */ -template class MessageDecoder { +template class MessageDecoder { public: virtual ~MessageDecoder() = default; + + /** + * Processes given buffer attempting to decode messages of type MessageType container within + * @param data buffer instance + */ virtual void onData(Buffer::Instance& data) PURE; }; /** * Kafka message decoder - * @tparam MT message type (Kafka request or Kafka response) + * @tparam MessageType message type (Kafka request or Kafka response) */ -template class MessageEncoder { +template class MessageEncoder { public: virtual ~MessageEncoder() = default; - virtual void encode(const MT& message) PURE; + + /** + * Encodes given message + * @param message message to be encoded + */ + virtual void encode(const MessageType& message) PURE; }; } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/debug_helpers.h b/source/extensions/filters/network/kafka/debug_helpers.h deleted file mode 100644 index 0a2f1c6bdf82a..0000000000000 --- a/source/extensions/filters/network/kafka/debug_helpers.h +++ /dev/null @@ -1,41 +0,0 @@ -#pragma once - -#include - -#include "absl/types/optional.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -// functions present in this header are used by request / response objects to print their fields -// nicely - -// prints out std::vector -template std::ostream& operator<<(std::ostream& os, const std::vector& arg) { - os << "["; - for (auto iter = arg.begin(); iter != arg.end(); iter++) { - if (iter != arg.begin()) { - os << ", "; - } - os << *iter; - } - os << "]"; - return os; -} - -// prints out absl::optional -template std::ostream& operator<<(std::ostream& os, const absl::optional& arg) { - if (arg.has_value()) { - os << *arg; - } else { - os << ""; - } - return os; -} - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_protocol.h b/source/extensions/filters/network/kafka/kafka_protocol.h index ecbe73aec1287..3d1f07498b37c 100644 --- a/source/extensions/filters/network/kafka/kafka_protocol.h +++ b/source/extensions/filters/network/kafka/kafka_protocol.h @@ -1,11 +1,5 @@ #pragma once -#include - -#include "common/common/macros.h" - -#include "extensions/filters/network/kafka/kafka_types.h" - namespace Envoy { namespace Extensions { namespace NetworkFilters { diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index c6003293dcf1d..45b3f7d64dd74 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -6,7 +6,6 @@ 
#include "common/common/assert.h" -#include "extensions/filters/network/kafka/debug_helpers.h" #include "extensions/filters/network/kafka/kafka_protocol.h" #include "extensions/filters/network/kafka/parser.h" #include "extensions/filters/network/kafka/serialization.h" @@ -31,12 +30,6 @@ struct RequestHeader { return api_key_ == rhs.api_key_ && api_version_ == rhs.api_version_ && correlation_id_ == rhs.correlation_id_ && client_id_ == rhs.client_id_; }; - - friend std::ostream& operator<<(std::ostream& os, const RequestHeader& arg) { - return os << "{api_key=" << arg.api_key_ << ", api_version=" << arg.api_version_ - << ", correlation_id=" << arg.correlation_id_ << ", client_id=" << arg.client_id_ - << "}"; - }; }; /** @@ -45,11 +38,6 @@ struct RequestHeader { struct RequestContext { int32_t remaining_request_size_{0}; RequestHeader request_header_{}; - - friend std::ostream& operator<<(std::ostream& os, const RequestContext& arg) { - return os << "{header=" << arg.request_header_ << ", remaining=" << arg.remaining_request_size_ - << "}"; - } }; typedef std::shared_ptr RequestContextSharedPtr; @@ -132,7 +120,7 @@ class RequestStartParser : public Parser { * Consumes INT32 bytes as request length and updates the context with that value * @return RequestHeaderParser instance to process request header */ - ParseResponse parse(const char*& buffer, uint64_t& remaining); + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; const RequestContextSharedPtr contextForTest() const { return context_; } @@ -142,14 +130,14 @@ class RequestStartParser : public Parser { Int32Deserializer request_length_; }; -// clang-format off /** - * Deserializer that extracts request header + * Deserializer that extracts request header (4 fields) * @see http://kafka.apache.org/protocol.html#protocol_messages */ class RequestHeaderDeserializer - : public CompositeDeserializerWith4Delegates {}; -// clang-format on + : public CompositeDeserializerWith4Delegates {}; /** * Parser responsible for computing request header and updating the context with data resolved @@ -165,7 +153,7 @@ class RequestHeaderParser : public Parser { * Uses data provided to compute request header * @return Parser instance responsible for processing rest of the message */ - ParseResponse parse(const char*& buffer, uint64_t& remaining); + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; const RequestContextSharedPtr contextForTest() const { return context_; } @@ -183,7 +171,7 @@ class RequestHeaderParser : public Parser { * @param BT deserializer type corresponding to request class (should be subclass of * Deserializer) */ -template class RequestParser : public Parser { +template class RequestParser : public Parser { public: /** * Create a parser with given context @@ -195,15 +183,14 @@ template class RequestParser : public Parser { * Consume enough data to fill in deserializer and receive the parsed request * Fill in request's header with data stored in context */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) { + ParseResponse parse(const char*& buffer, uint64_t& remaining) override { context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); if (deserializer.ready()) { // after a successful parse, there should be nothing left - we have consumed all the bytes ASSERT(0 == context_->remaining_request_size_); - RT request = deserializer.get(); + RequestType request = deserializer.get(); request.header() = context_->request_header_; - ENVOY_LOG(trace, "parsed request {}: {}", 
*context_, request); - MessageSharedPtr msg = std::make_shared(request); + MessageSharedPtr msg = std::make_shared(request); return ParseResponse::parsedMessage(msg); } else { return ParseResponse::stillWaiting(); @@ -212,7 +199,7 @@ template class RequestParser : public Parser { protected: RequestContextSharedPtr context_; - BT deserializer; // underlying request-specific deserializer + DeserializerType deserializer; // underlying request-specific deserializer }; /** @@ -266,27 +253,12 @@ class Request : public Message { return written; } - /** - * Pretty-prints given request into a stream - */ - std::ostream& print(std::ostream& os) const override final { - // write header - os << request_header_ << " "; // not very pretty - // write request-specific data - return printDetails(os); - } - protected: /** * Encodes request-specific data into a buffer */ virtual size_t encodeDetails(Buffer::Instance&, EncodingContext&) const PURE; - /** - * Prints request-specific data into a stream - */ - virtual std::ostream& printDetails(std::ostream&) const PURE; - RequestHeader request_header_; }; @@ -305,10 +277,6 @@ class UnknownRequest : public Request { size_t encodeDetails(Buffer::Instance&, EncodingContext&) const override { throw EnvoyException("cannot serialize unknown request"); } - - std::ostream& printDetails(std::ostream& out) const override { - return out << "{unknown request}"; - } }; /** diff --git a/source/extensions/filters/network/kafka/kafka_types.h b/source/extensions/filters/network/kafka/kafka_types.h index 0fe1d9591b548..4d7a6a09cd364 100644 --- a/source/extensions/filters/network/kafka/kafka_types.h +++ b/source/extensions/filters/network/kafka/kafka_types.h @@ -11,31 +11,22 @@ namespace NetworkFilters { namespace Kafka { /** - * Represents a sequence of characters or null. For non-null strings, first the length N is given as - * an INT16. Then N bytes follow which are the UTF-8 encoding of the character sequence. A null - * value is encoded with length of -1 and there are no following bytes. + * Nullable string used by Kafka */ typedef absl::optional NullableString; /** - * Represents a raw sequence of bytes. - * First the length N is given as an INT32. Then N bytes follow. + * Bytes array used by Kafka */ typedef std::vector Bytes; /** - * Represents a raw sequence of bytes or null. For non-null values, first the length N is given as - * an INT32. Then N bytes follow. A null value is encoded with length of -1 and there are no - * following bytes. + * Nullable bytes array used by Kafka */ typedef absl::optional NullableBytes; /** - * Represents a sequence of objects of a given type T. - * Type T can be either a primitive type (e.g. STRING) or a structure. - * First, the length N is given as an INT32. - * Then N instances of type T follow. - * A null array is represented with a length of -1. 
+ * Kafka array of elements of type T */ template using NullableArray = absl::optional>; diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h index b54c3743bbf07..7551888fe05e3 100644 --- a/source/extensions/filters/network/kafka/message.h +++ b/source/extensions/filters/network/kafka/message.h @@ -16,11 +16,6 @@ namespace Kafka { class Message { public: virtual ~Message() = default; - - friend std::ostream& operator<<(std::ostream& out, const Message& arg) { return arg.print(out); } - -protected: - virtual std::ostream& print(std::ostream& os) const PURE; }; typedef std::shared_ptr MessageSharedPtr; diff --git a/source/extensions/filters/network/kafka/messages/offset_commit.h b/source/extensions/filters/network/kafka/messages/offset_commit.h index e1b6288468f29..3e81bee3b6df9 100644 --- a/source/extensions/filters/network/kafka/messages/offset_commit.h +++ b/source/extensions/filters/network/kafka/messages/offset_commit.h @@ -44,11 +44,6 @@ struct OffsetCommitPartition { return partition_ == rhs.partition_ && offset_ == rhs.offset_ && timestamp_ == rhs.timestamp_ && metadata_ == rhs.metadata_; }; - - friend std::ostream& operator<<(std::ostream& os, const OffsetCommitPartition& arg) { - return os << "{partition=" << arg.partition_ << ", offset=" << arg.offset_ - << ", timestamp=" << arg.timestamp_ << ", metadata=" << arg.metadata_ << "}"; - } }; /** @@ -68,10 +63,6 @@ struct OffsetCommitTopic { bool operator==(const OffsetCommitTopic& rhs) const { return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; }; - - friend std::ostream& operator<<(std::ostream& os, const OffsetCommitTopic& arg) { - return os << "{topic=" << arg.topic_ << ", partitions_=" << arg.partitions_ << "}"; - } }; /** @@ -116,12 +107,6 @@ class OffsetCommitRequest : public Request { return written; } - std::ostream& printDetails(std::ostream& os) const override { - return os << "{group_id=" << group_id_ << ", group_generation_id=" << group_generation_id_ - << ", member_id=" << member_id_ << ", retention_time=" << retention_time_ - << ", topics=" << topics_ << "}"; - } - private: const std::string group_id_; const int32_t group_generation_id_; // since v1 @@ -130,45 +115,51 @@ class OffsetCommitRequest : public Request { const NullableArray topics_; }; -// clang-format off - // api version 0 // Deserializes bytes into OffsetCommitPartition (api version 0): partition, offset, metadata class OffsetCommitPartitionV0Buffer - : public CompositeDeserializerWith3Delegates {}; + : public CompositeDeserializerWith3Delegates {}; // Deserializes array of OffsetCommitPartition-s v0 class OffsetCommitPartitionV0ArrayBuffer : public ArrayDeserializer {}; // Deserializes bytes into OffsetCommitTopic (api version 0): topic name, partitions (v0) class OffsetCommitTopicV0Buffer - : public CompositeDeserializerWith2Delegates {}; + : public CompositeDeserializerWith2Delegates {}; // Deserializes array of OffsetCommitTopic-s v0 class OffsetCommitTopicV0ArrayBuffer : public ArrayDeserializer {}; // Deserializes bytes into OffsetCommitRequest (api version 0): group_id, topics (v0) class OffsetCommitRequestV0Deserializer - : public CompositeDeserializerWith2Delegates {}; + : public CompositeDeserializerWith2Delegates {}; // api version 1 -// Deserializes bytes into OffsetCommitPartition (api version 1): partition, offset, timestamp, metadata +// Deserializes bytes into OffsetCommitPartition (api version 1): partition, offset, timestamp, +// metadata class 
OffsetCommitPartitionV1Buffer - : public CompositeDeserializerWith4Delegates {}; + : public CompositeDeserializerWith4Delegates {}; // Deserializes array of OffsetCommitPartition-s v1 class OffsetCommitPartitionV1ArrayBuffer : public ArrayDeserializer {}; // Deserializes bytes into OffsetCommitTopic (api version 1): topic name, partitions (v1) class OffsetCommitTopicV1Buffer - : public CompositeDeserializerWith2Delegates {}; + : public CompositeDeserializerWith2Delegates {}; // Deserializes array of OffsetCommitTopic-s v1 class OffsetCommitTopicV1ArrayBuffer : public ArrayDeserializer {}; -// Deserializes bytes into OffsetCommitRequest (api version 1): group_id, generation_id, member_id, topics (v1) +// Deserializes bytes into OffsetCommitRequest (api version 1): group_id, generation_id, member_id, +// topics (v1) class OffsetCommitRequestV1Deserializer - : public CompositeDeserializerWith4Delegates {}; - -// clang-format on + : public CompositeDeserializerWith4Delegates {}; /** * Define Parsers that wrap the corresponding deserializers diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index 04468f2978395..b51053337a7cb 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -41,7 +41,6 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl // next parser is not present, so we have finished parsing a message MessageSharedPtr message = result.message_; - ENVOY_LOG(trace, "parsed message: {}", *message); for (auto& callback : callbacks_) { callback->onMessage(result.message_); } diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index 77e40fccdd856..f7931b8a5bb35 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -14,13 +14,17 @@ namespace NetworkFilters { namespace Kafka { /** - * Invoked when request is successfully decoded + * Callback invoked when request is successfully decoded */ class RequestCallback { public: virtual ~RequestCallback() = default; - virtual void onMessage(MessageSharedPtr) PURE; + /** + * Callback method invoked when request is successfully decoded + * @param request request that has been decoded + */ + virtual void onMessage(MessageSharedPtr request) PURE; }; typedef std::shared_ptr RequestCallbackSharedPtr; @@ -53,7 +57,7 @@ class RequestDecoder : public MessageDecoder, public Logger::Loggable { * Values 0 and 1 are used to represent false and true respectively. * When reading a boolean value, any non-zero value is considered true. */ -class BoolBuffer : public Deserializer { +class BooleanDeserializer : public Deserializer { public: - BoolBuffer(){}; + BooleanDeserializer(){}; size_t feed(const char*& buffer, uint64_t& remaining) { return buffer_.feed(buffer, remaining); } @@ -172,7 +172,7 @@ class BoolBuffer : public Deserializer { * First reads length (INT16) and then allocates the buffer of given length * * From documentation: - * First the length N is given as an int16_t. + * First the length N is given as an INT16. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * Length must not be negative. */ @@ -231,7 +231,7 @@ class StringDeserializer : public Deserializer { * null value) * * From documentation: - * For non-null strings, first the length N is given as an int16_t. 
+ * For non-null strings, first the length N is given as an INT16. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * A null value is encoded with length of -1 and there are no following bytes. */ @@ -298,22 +298,151 @@ class NullableStringDeserializer : public Deserializer { bool ready_{false}; }; +/** + * Deserializer of bytes value + * First reads length (INT32) and then allocates the buffer of given length + * + * From documentation: + * First the length N is given as an INT32. Then N bytes follow. + */ +class BytesDeserializer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } else { + throw EnvoyException(fmt::format("invalid BYTES length: {}", required_)); + } + length_consumed_ = true; + } + + const size_t data_consumed = std::min(required_, remaining); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, buffer, data_consumed); + required_ -= data_consumed; + + buffer += data_consumed; + remaining -= data_consumed; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const { return ready_; } + + Bytes get() const { return data_buf_; } + +private: + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + + std::vector data_buf_; + bool ready_{false}; +}; + +/** + * Deserializer of nullable bytes value + * First reads length (INT32) and then allocates the buffer of given length + * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning + * null value) + * + * From documentation: + * For non-null values, first the length N is given as an INT32. Then N bytes follow. + * A null value is encoded with length of -1 and there are no following bytes. 
+ */ +class NullableBytesDeserializer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + + if (required_ >= 0) { + data_buf_ = std::vector(required_); + } + if (required_ == NULL_BYTES_LENGTH) { + ready_ = true; + } + if (required_ < NULL_BYTES_LENGTH) { + throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); + } + + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + const size_t data_consumed = std::min(required_, remaining); + const size_t written = data_buf_.size() - required_; + memcpy(data_buf_.data() + written, buffer, data_consumed); + required_ -= data_consumed; + + buffer += data_consumed; + remaining -= data_consumed; + + if (required_ == 0) { + ready_ = true; + } + + return length_consumed + data_consumed; + } + + bool ready() const { return ready_; } + + NullableBytes get() const { + if (NULL_BYTES_LENGTH == required_) { + return absl::nullopt; + } else { + return {data_buf_}; + } + } + +private: + constexpr static int32_t NULL_BYTES_LENGTH{-1}; + + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + + std::vector data_buf_; + bool ready_{false}; +}; + /** * Deserializer for array of objects of the same type * - * First reads the length of the array, then initializes N underlying deserializers of type CT - * After the last of N deserializers is ready, the results of each of them are gathered and put in a - * vector - * @param RT result type returned by deserializer CT - * @param CT underlying deserializer type + * First reads the length of the array, then initializes N underlying deserializers of type + * DeserializerType After the last of N deserializers is ready, the results of each of them are + * gathered and put in a vector + * @param ResponseType result type returned by deserializer of type DeserializerType + * @param DeserializerType underlying deserializer type * * From documentation: * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T * follow. A null array is represented with a length of -1. 
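 *
 * Worked example (illustrative): an array of two INT16 values {1, 2} is
 * encoded as 00 00 00 02 00 01 00 02 (big-endian INT32 element count, then
 * each element in order), and a null array as ff ff ff ff.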
*/ -template -class ArrayDeserializer : public Deserializer> { +template +class ArrayDeserializer : public Deserializer> { public: size_t feed(const char*& buffer, uint64_t& remaining) { @@ -327,7 +456,7 @@ class ArrayDeserializer : public Deserializer> { required_ = length_buf_.get(); if (required_ >= 0) { - children_ = std::vector(required_); + children_ = std::vector(required_); } if (required_ == NULL_ARRAY_LENGTH) { ready_ = true; @@ -344,12 +473,12 @@ class ArrayDeserializer : public Deserializer> { } size_t child_consumed{0}; - for (CT& child : children_) { + for (DeserializerType& child : children_) { child_consumed += child.feed(buffer, remaining); } bool children_ready_ = true; - for (CT& child : children_) { + for (DeserializerType& child : children_) { children_ready_ &= child.ready(); } ready_ = children_ready_; @@ -359,12 +488,12 @@ class ArrayDeserializer : public Deserializer> { bool ready() const { return ready_; } - NullableArray get() const { + NullableArray get() const { if (NULL_ARRAY_LENGTH != required_) { - std::vector result{}; + std::vector result{}; result.reserve(children_.size()); - for (const CT& child : children_) { - const RT child_result = child.get(); + for (const DeserializerType& child : children_) { + const ResponseType child_result = child.get(); result.push_back(child_result); } return {result}; @@ -379,24 +508,11 @@ class ArrayDeserializer : public Deserializer> { Int32Deserializer length_buf_; bool length_consumed_{false}; int32_t required_; - std::vector children_; + std::vector children_; bool children_setup_{false}; bool ready_{false}; }; -/** - * Trivial deserializer that is always ready, and consumes no bytes - * Used in situations when value is always present and returns a constant - */ -template class NullDeserializer : public Deserializer { -public: - size_t feed(const char*&, uint64_t&) { return 0; } - - bool ready() const { return true; } - - RT get() const { return {}; } -}; - /** * Encodes provided argument in Kafka format * In case of primitive types, this is done explicitly as per spec diff --git a/source/extensions/filters/network/kafka/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h index 6147dabc2dab0..14f96ae8a1ad2 100644 --- a/source/extensions/filters/network/kafka/serialization_composite.h +++ b/source/extensions/filters/network/kafka/serialization_composite.h @@ -22,8 +22,8 @@ namespace Kafka { /** * This header contains only composite deserializers - * The basic design is composite deserializer creating delegates T1..Tn - * Result of type RT is constructed by getting results of each of delegates + * The basic design is composite deserializer creating delegates DeserializerType1..Tn + * Result of type ResponseType is constructed by getting results of each of delegates */ /** @@ -32,14 +32,14 @@ namespace Kafka { * (deserializers that are already ready do not consume data, so it's safe). * The composite deserializer is ready when the last deserializer is ready * (which means all deserializers before it are ready too) - * Constructs the result of type RT using { delegate1_.get(), delegate2_.get() ... } + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
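+ * (for example, OffsetCommitTopicV0Buffer in messages/offset_commit.h is such a
+ * composite: a StringDeserializer plus a partition-array deserializer whose
+ * results construct an OffsetCommitTopic)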
} * - * @param RT type of deserialized data - * @param T1 1st deserializer (result used as 1st argument of RT's ctor) - * @param T2 2nd deserializer (result used as 2nd argument of RT's ctor) + * @param ResponseType type of deserialized data + * @param DeserializerType1 1st deserializer (result used as 1st argument of ResponseType's ctor) + * @param DeserializerType2 2nd deserializer (result used as 2nd argument of ResponseType's ctor) */ -template -class CompositeDeserializerWith2Delegates : public Deserializer { +template +class CompositeDeserializerWith2Delegates : public Deserializer { public: CompositeDeserializerWith2Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -49,11 +49,11 @@ class CompositeDeserializerWith2Delegates : public Deserializer { return consumed; } bool ready() const { return delegate2_.ready(); } - RT get() const { return {delegate1_.get(), delegate2_.get()}; } + ResponseType get() const { return {delegate1_.get(), delegate2_.get()}; } protected: - T1 delegate1_; - T2 delegate2_; + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; }; /** @@ -62,15 +62,16 @@ class CompositeDeserializerWith2Delegates : public Deserializer { * (deserializers that are already ready do not consume data, so it's safe). * The composite deserializer is ready when the last deserializer is ready * (which means all deserializers before it are ready too) - * Constructs the result of type RT using { delegate1_.get(), delegate2_.get() ... } + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } * - * @param RT type of deserialized data - * @param T1 1st deserializer (result used as 1st argument of RT's ctor) - * @param T2 2nd deserializer (result used as 2nd argument of RT's ctor) - * @param T3 3rd deserializer (result used as 3rd argument of RT's ctor) + * @param ResponseType type of deserialized data + * @param DeserializerType1 1st deserializer (result used as 1st argument of ResponseType's ctor) + * @param DeserializerType2 2nd deserializer (result used as 2nd argument of ResponseType's ctor) + * @param DeserializerType3 3rd deserializer (result used as 3rd argument of ResponseType's ctor) */ -template -class CompositeDeserializerWith3Delegates : public Deserializer { +template +class CompositeDeserializerWith3Delegates : public Deserializer { public: CompositeDeserializerWith3Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -81,12 +82,12 @@ class CompositeDeserializerWith3Delegates : public Deserializer { return consumed; } bool ready() const { return delegate3_.ready(); } - RT get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; } + ResponseType get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; } protected: - T1 delegate1_; - T2 delegate2_; - T3 delegate3_; + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; }; /** @@ -95,16 +96,17 @@ class CompositeDeserializerWith3Delegates : public Deserializer { * (deserializers that are already ready do not consume data, so it's safe). * The composite deserializer is ready when the last deserializer is ready * (which means all deserializers before it are ready too) - * Constructs the result of type RT using { delegate1_.get(), delegate2_.get() ... } + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
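+ * (for example, RequestHeaderDeserializer in kafka_request.h is such a
+ * composite: two INT16 delegates, an INT32 delegate and a nullable-string
+ * delegate whose results construct a RequestHeader)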
} * - * @param RT type of deserialized data - * @param T1 1st deserializer (result used as 1st argument of RT's ctor) - * @param T2 2nd deserializer (result used as 2nd argument of RT's ctor) - * @param T3 3rd deserializer (result used as 3rd argument of RT's ctor) - * @param T4 4th deserializer (result used as 4th argument of RT's ctor) + * @param ResponseType type of deserialized data + * @param DeserializerType1 1st deserializer (result used as 1st argument of ResponseType's ctor) + * @param DeserializerType2 2nd deserializer (result used as 2nd argument of ResponseType's ctor) + * @param DeserializerType3 3rd deserializer (result used as 3rd argument of ResponseType's ctor) + * @param DeserializerType4 4th deserializer (result used as 4th argument of ResponseType's ctor) */ -template -class CompositeDeserializerWith4Delegates : public Deserializer { +template +class CompositeDeserializerWith4Delegates : public Deserializer { public: CompositeDeserializerWith4Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -116,15 +118,15 @@ class CompositeDeserializerWith4Delegates : public Deserializer { return consumed; } bool ready() const { return delegate4_.ready(); } - RT get() const { + ResponseType get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get()}; } protected: - T1 delegate1_; - T2 delegate2_; - T3 delegate3_; - T4 delegate4_; + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; }; } // namespace Kafka diff --git a/source/extensions/filters/network/well_known_names.h b/source/extensions/filters/network/well_known_names.h index 466178b37ed8f..6a68c32223c41 100644 --- a/source/extensions/filters/network/well_known_names.h +++ b/source/extensions/filters/network/well_known_names.h @@ -30,8 +30,6 @@ class NetworkFilterNameValues { const std::string TcpProxy = "envoy.tcp_proxy"; // Authorization filter const std::string ExtAuthorization = "envoy.ext_authz"; - // Kafka filter - const std::string Kafka = "envoy.filters.network.kafka"; // Thrift proxy filter const std::string ThriftProxy = "envoy.filters.network.thrift_proxy"; // Role based access control filter diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 9555b3e639c9b..65655ec7a460c 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -32,9 +32,13 @@ TEST_EmptyDeserializerShouldNotBeReady(Int16Deserializer); TEST_EmptyDeserializerShouldNotBeReady(Int32Deserializer); TEST_EmptyDeserializerShouldNotBeReady(UInt32Deserializer); TEST_EmptyDeserializerShouldNotBeReady(Int64Deserializer); -TEST_EmptyDeserializerShouldNotBeReady(BoolBuffer); +TEST_EmptyDeserializerShouldNotBeReady(BooleanDeserializer); + TEST_EmptyDeserializerShouldNotBeReady(StringDeserializer); TEST_EmptyDeserializerShouldNotBeReady(NullableStringDeserializer); +TEST_EmptyDeserializerShouldNotBeReady(BytesDeserializer); +TEST_EmptyDeserializerShouldNotBeReady(NullableBytesDeserializer); + TEST(CompositeDeserializerWith2Delegates, EmptyBufferShouldNotBeReady) { // given struct CompositeResult { @@ -74,15 +78,6 @@ TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) { ASSERT_EQ(testee.ready(), false); } -// Null deserializer is a special case, it's always ready and can provide results via 0-arg ctor -TEST(NullDeserializer, EmptyBufferShouldBeReady) { - // given - 
const NullDeserializer testee{};
-  // when, then
-  ASSERT_EQ(testee.ready(), true);
-  ASSERT_EQ(testee.get(), 0);
-}
-
 EncodingContext encoder{-1}; // context is not used when serializing primitive types

 // helper function
@@ -191,7 +186,7 @@ TEST_DeserializerShouldDeserialize(Int16Deserializer, int16_t, 42);
 TEST_DeserializerShouldDeserialize(Int32Deserializer, int32_t, 42);
 TEST_DeserializerShouldDeserialize(UInt32Deserializer, uint32_t, 42);
 TEST_DeserializerShouldDeserialize(Int64Deserializer, int64_t, 42);
-TEST_DeserializerShouldDeserialize(BoolBuffer, bool, true);
+TEST_DeserializerShouldDeserialize(BooleanDeserializer, bool, true);

 TEST(StringDeserializer, ShouldDeserialize) {
   const std::string value = "sometext";
@@ -253,6 +248,63 @@ TEST(NullableStringDeserializer, ShouldThrowOnInvalidLength) {
   EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
 }

+TEST(BytesDeserializer, ShouldDeserialize) {
+  const Bytes value{'a', 'b', 'c', 'd'};
+  serializeThenDeserializeAndCheckEquality(value);
+}
+
+TEST(BytesDeserializer, ShouldDeserializeEmptyBytes) {
+  const Bytes value{};
+  serializeThenDeserializeAndCheckEquality(value);
+}
+
+TEST(BytesDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  BytesDeserializer testee;
+  Buffer::OwnedImpl buffer;
+
+  const int32_t bytes_length = -1; // BYTES accepts only >= 0
+  encoder.encode(bytes_length, buffer);
+
+  uint64_t remaining = 1024;
+  const char* data = getRawData(buffer);
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
+TEST(NullableBytesDeserializer, ShouldDeserialize) {
+  const NullableBytes value{{'a', 'b', 'c', 'd'}};
+  serializeThenDeserializeAndCheckEquality(value);
+}
+
+TEST(NullableBytesDeserializer, ShouldDeserializeEmptyBytes) {
+  const NullableBytes value{{}};
+  serializeThenDeserializeAndCheckEquality(value);
+}
+
+TEST(NullableBytesDeserializer, ShouldDeserializeNullBytes) {
+  const NullableBytes value = absl::nullopt;
+  serializeThenDeserializeAndCheckEquality(value);
+}
+
+TEST(NullableBytesDeserializer, ShouldThrowOnInvalidLength) {
+  // given
+  NullableBytesDeserializer testee;
+  Buffer::OwnedImpl buffer;
+
+  const int32_t bytes_length = -2; // -1 is OK for NULLABLE_BYTES
+  encoder.encode(bytes_length, buffer);
+
+  uint64_t remaining = 1024;
+  const char* data = getRawData(buffer);
+
+  // when
+  // then
+  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+}
+
 TEST(ArrayDeserializer, ShouldConsumeCorrectAmountOfData) {
   const NullableArray value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}};
   serializeThenDeserializeAndCheckEquality<ArrayDeserializer<std::string, StringDeserializer>>(

From d0534c3bfdc9e4b0c30fa4f43ef70c27c818d5de Mon Sep 17 00:00:00 2001
From: Adam Kotwasinski
Date: Thu, 10 Jan 2019 13:35:55 +0000
Subject: [PATCH 10/29] Introduce generated code for Kafka requests:
 - requests.h - Kafka request specification
 - kafka_request_resolver.cc - mapping from api key + api version to request parser
 - serialization_composite.h - composite deserializers
 - requests_test.cc - tests for serialization of requests

Signed-off-by: Adam Kotwasinski

---
 source/extensions/filters/network/kafka/BUILD |  10 +-
 .../kafka/generated/kafka_request_resolver.cc |  36 ++
 .../network/kafka/generated/requests.h        | 448 +++++++++++++
 .../kafka/generated/serialization_composite.h | 609 ++++++++++++++++++
 .../filters/network/kafka/kafka_protocol.h    |  19 -
 .../filters/network/kafka/kafka_request.cc    |  67 +-
 .../filters/network/kafka/kafka_request.h     |  84 +--
 .../network/kafka/messages/offset_commit.h    | 174 -----
 .../filters/network/kafka/request_codec.cc    |   7 +-
 .../filters/network/kafka/request_codec.h     |   4 +-
 .../filters/network/kafka/serialization.h     |  11 -
 .../network/kafka/serialization_composite.h   | 135 ----
 test/extensions/filters/network/kafka/BUILD   |  20 +
 .../network/kafka/generated/requests_test.cc  |  95 +++
 .../generated/serialization_composite_test.cc | 484 ++++++++++++++
 .../network/kafka/kafka_request_test.cc       |  65 +-
 .../network/kafka/request_codec_test.cc       |  63 +-
 .../network/kafka/serialization_test.cc       | 124 +---
 18 files changed, 1747 insertions(+), 708 deletions(-)
 create mode 100644 source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc
 create mode 100644 source/extensions/filters/network/kafka/generated/requests.h
 create mode 100644 source/extensions/filters/network/kafka/generated/serialization_composite.h
 delete mode 100644 source/extensions/filters/network/kafka/kafka_protocol.h
 delete mode 100644 source/extensions/filters/network/kafka/messages/offset_commit.h
 delete mode 100644 source/extensions/filters/network/kafka/serialization_composite.h
 create mode 100644 test/extensions/filters/network/kafka/generated/requests_test.cc
 create mode 100644 test/extensions/filters/network/kafka/generated/serialization_composite_test.cc

diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD
index cbd1bc719cf54..e4c73e532b025 100644
--- a/source/extensions/filters/network/kafka/BUILD
+++ b/source/extensions/filters/network/kafka/BUILD
@@ -26,10 +26,13 @@ envoy_cc_library(
 envoy_cc_library(
     name = "kafka_request_lib",
-    srcs = ["kafka_request.cc"],
+    srcs = [
+        "generated/kafka_request_resolver.cc",
+        "kafka_request.cc",
+    ],
     hdrs = [
+        "generated/requests.h",
         "kafka_request.h",
-        "messages/offset_commit.h",
     ],
     deps = [
         ":parser_lib",
@@ -61,8 +64,8 @@ envoy_cc_library(
 envoy_cc_library(
     name = "serialization_lib",
     hdrs = [
+        "generated/serialization_composite.h",
         "serialization.h",
-        "serialization_composite.h",
     ],
     deps = [
         ":kafka_protocol_lib",
@@ -74,7 +77,6 @@ envoy_cc_library(
 envoy_cc_library(
     name = "kafka_protocol_lib",
     hdrs = [
-        "kafka_protocol.h",
         "kafka_types.h",
     ],
     external_deps = ["abseil_optional"],

diff --git a/source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc b/source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc
new file mode 100644
index 0000000000000..2e8c26f25bad5
--- /dev/null
+++ b/source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc
@@ -0,0 +1,36 @@
+// DO NOT EDIT - THIS FILE WAS GENERATED
+// clang-format off
+#include "extensions/filters/network/kafka/generated/requests.h"
+#include "extensions/filters/network/kafka/kafka_request.h"
+#include "extensions/filters/network/kafka/parser.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+const RequestParserResolver RequestParserResolver::INSTANCE;
+
+ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version,
+                                                    RequestContextSharedPtr context) const {
+
+  if (8 == api_key && 0 == api_version) {
+    return std::make_shared<OffsetCommitRequestV0Parser>(context);
+  }
+  if (8 == api_key && 1 == api_version) {
+    return std::make_shared<OffsetCommitRequestV1Parser>(context);
+  }
+  if (8 == api_key && 2 == api_version) {
+    return std::make_shared<OffsetCommitRequestV2Parser>(context);
+  }
+  if (8 == api_key && 3 == api_version) {
+    return std::make_shared<OffsetCommitRequestV3Parser>(context);
+  }
+  return std::make_shared(context);
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
+// clang-format on
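The generated resolver above replaces the hand-maintained ParserSpec/KAFKA_0_11 tables that this patch deletes from kafka_request.cc further below. As a usage sketch, assuming a default-constructed RequestContext is acceptable at this point of parsing:

// Resolving a parser for OffsetCommit (api key 8) version 2; any key/version
// pair without a generated branch falls through to the catch-all return at
// the end of createParser.
RequestContextSharedPtr context = std::make_shared<RequestContext>();
ParserSharedPtr parser = RequestParserResolver::INSTANCE.createParser(8, 2, context);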
diff --git a/source/extensions/filters/network/kafka/generated/requests.h b/source/extensions/filters/network/kafka/generated/requests.h
new file mode 100644
index 0000000000000..566f3ea012a8b
--- /dev/null
+++ b/source/extensions/filters/network/kafka/generated/requests.h
@@ -0,0 +1,448 @@
+// DO NOT EDIT - THIS FILE WAS GENERATED
+// clang-format off
+#pragma once
+#include "extensions/filters/network/kafka/kafka_request.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+/* Represents 'partitions' element in OffsetCommitRequestV0 */
+struct OffsetCommitRequestV0Partition {
+  const int32_t partition_;
+  const int64_t offset_;
+  const NullableString metadata_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(offset_, dst);
+    written += encoder.encode(metadata_, dst);
+    return written;
+  }
+
+  bool operator==(const OffsetCommitRequestV0Partition& rhs) const {
+    return
+      partition_ == rhs.partition_ &&
+      offset_ == rhs.offset_ &&
+      metadata_ == rhs.metadata_;
+  };
+
+};
+
+class OffsetCommitRequestV0PartitionDeserializer:
+  public CompositeDeserializerWith3Delegates<
+    OffsetCommitRequestV0Partition,
+    Int32Deserializer,
+    Int64Deserializer,
+    NullableStringDeserializer
+  >{};
+
+/* Represents 'topics' element in OffsetCommitRequestV0 */
+struct OffsetCommitRequestV0Topic {
+  const std::string topic_;
+  const NullableArray<OffsetCommitRequestV0Partition> partitions_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(topic_, dst);
+    written += encoder.encode(partitions_, dst);
+    return written;
+  }
+
+  bool operator==(const OffsetCommitRequestV0Topic& rhs) const {
+    return
+      topic_ == rhs.topic_ &&
+      partitions_ == rhs.partitions_;
+  };
+
+};
+
+class OffsetCommitRequestV0TopicDeserializer:
+  public CompositeDeserializerWith2Delegates<
+    OffsetCommitRequestV0Topic,
+    StringDeserializer,
+    ArrayDeserializer<OffsetCommitRequestV0Partition, OffsetCommitRequestV0PartitionDeserializer>
+  >{};
+
+class OffsetCommitRequestV0 : public Request {
+public:
+  OffsetCommitRequestV0(
+    std::string group_id,
+    NullableArray<OffsetCommitRequestV0Topic> topics
+  ):
+    Request{8, 0},
+    group_id_{group_id},
+    topics_{topics}
+  {};
+
+  bool operator==(const OffsetCommitRequestV0& rhs) const {
+    return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && topics_ == rhs.topics_;
+  };
+
+protected:
+  size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override {
+    size_t written{0};
+    written += encoder.encode(group_id_, dst);
+    written += encoder.encode(topics_, dst);
+    return written;
+  }
+
+private:
+  const std::string group_id_;
+  const NullableArray<OffsetCommitRequestV0Topic> topics_;
+};
+
+class OffsetCommitRequestV0Deserializer:
+  public CompositeDeserializerWith2Delegates<
+    OffsetCommitRequestV0,
+    StringDeserializer,
+    ArrayDeserializer<OffsetCommitRequestV0Topic, OffsetCommitRequestV0TopicDeserializer>
+  >{};
+
+class OffsetCommitRequestV0Parser : public RequestParser<OffsetCommitRequestV0, OffsetCommitRequestV0Deserializer> {
+public:
+  OffsetCommitRequestV0Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {};
+};
+
+/* Represents 'partitions' element in OffsetCommitRequestV1 */
+struct OffsetCommitRequestV1Partition {
+  const int32_t partition_;
+  const int64_t offset_;
+  const int64_t timestamp_;
+  const NullableString metadata_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(partition_, dst);
+    written += encoder.encode(offset_, dst);
+    written += encoder.encode(timestamp_, dst);
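+    // v1 is the only OffsetCommit version that carries the per-partition
+    // timestamp on the wire; the v0 and v2/v3 partition structs in this file
+    // omit the field. The remaining field below is the NULLABLE_STRING metadata.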
written += encoder.encode(metadata_, dst); + return written; + } + + bool operator==(const OffsetCommitRequestV1Partition& rhs) const { + return + partition_ == rhs.partition_ && + offset_ == rhs.offset_ && + timestamp_ == rhs.timestamp_ && + metadata_ == rhs.metadata_; + }; + +}; + +class OffsetCommitRequestV1PartitionDeserializer: + public CompositeDeserializerWith4Delegates< + OffsetCommitRequestV1Partition, + Int32Deserializer, + Int64Deserializer, + Int64Deserializer, + NullableStringDeserializer + >{}; + +/* Represents 'topics' element in OffsetCommitRequestV1 */ +struct OffsetCommitRequestV1Topic { + const std::string topic_; + const NullableArray partitions_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(topic_, dst); + written += encoder.encode(partitions_, dst); + return written; + } + + bool operator==(const OffsetCommitRequestV1Topic& rhs) const { + return + topic_ == rhs.topic_ && + partitions_ == rhs.partitions_; + }; + +}; + +class OffsetCommitRequestV1TopicDeserializer: + public CompositeDeserializerWith2Delegates< + OffsetCommitRequestV1Topic, + StringDeserializer, + ArrayDeserializer + >{}; + +class OffsetCommitRequestV1 : public Request { +public: + OffsetCommitRequestV1( + std::string group_id, + int32_t generation_id, + std::string member_id, + NullableArray topics + ): + Request{8, 1}, + group_id_{group_id}, + generation_id_{generation_id}, + member_id_{member_id}, + topics_{topics} + {}; + + bool operator==(const OffsetCommitRequestV1& rhs) const { + return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && generation_id_ == rhs.generation_id_ && member_id_ == rhs.member_id_ && topics_ == rhs.topics_; + }; + +protected: + size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { + size_t written{0}; + written += encoder.encode(group_id_, dst); + written += encoder.encode(generation_id_, dst); + written += encoder.encode(member_id_, dst); + written += encoder.encode(topics_, dst); + return written; + } + +private: + const std::string group_id_; + const int32_t generation_id_; + const std::string member_id_; + const NullableArray topics_; +}; + +class OffsetCommitRequestV1Deserializer: + public CompositeDeserializerWith4Delegates< + OffsetCommitRequestV1, + StringDeserializer, + Int32Deserializer, + StringDeserializer, + ArrayDeserializer + >{}; + +class OffsetCommitRequestV1Parser : public RequestParser { +public: + OffsetCommitRequestV1Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; +}; + +/* Represents 'partitions' element in OffsetCommitRequestV2 */ +struct OffsetCommitRequestV2Partition { + const int32_t partition_; + const int64_t offset_; + const NullableString metadata_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(partition_, dst); + written += encoder.encode(offset_, dst); + written += encoder.encode(metadata_, dst); + return written; + } + + bool operator==(const OffsetCommitRequestV2Partition& rhs) const { + return + partition_ == rhs.partition_ && + offset_ == rhs.offset_ && + metadata_ == rhs.metadata_; + }; + +}; + +class OffsetCommitRequestV2PartitionDeserializer: + public CompositeDeserializerWith3Delegates< + OffsetCommitRequestV2Partition, + Int32Deserializer, + Int64Deserializer, + NullableStringDeserializer + >{}; + +/* Represents 'topics' element in OffsetCommitRequestV2 */ +struct OffsetCommitRequestV2Topic { + const std::string 
topic_; + const NullableArray partitions_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(topic_, dst); + written += encoder.encode(partitions_, dst); + return written; + } + + bool operator==(const OffsetCommitRequestV2Topic& rhs) const { + return + topic_ == rhs.topic_ && + partitions_ == rhs.partitions_; + }; + +}; + +class OffsetCommitRequestV2TopicDeserializer: + public CompositeDeserializerWith2Delegates< + OffsetCommitRequestV2Topic, + StringDeserializer, + ArrayDeserializer + >{}; + +class OffsetCommitRequestV2 : public Request { +public: + OffsetCommitRequestV2( + std::string group_id, + int32_t generation_id, + std::string member_id, + int64_t retention_time, + NullableArray topics + ): + Request{8, 2}, + group_id_{group_id}, + generation_id_{generation_id}, + member_id_{member_id}, + retention_time_{retention_time}, + topics_{topics} + {}; + + bool operator==(const OffsetCommitRequestV2& rhs) const { + return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && generation_id_ == rhs.generation_id_ && member_id_ == rhs.member_id_ && retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; + }; + +protected: + size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { + size_t written{0}; + written += encoder.encode(group_id_, dst); + written += encoder.encode(generation_id_, dst); + written += encoder.encode(member_id_, dst); + written += encoder.encode(retention_time_, dst); + written += encoder.encode(topics_, dst); + return written; + } + +private: + const std::string group_id_; + const int32_t generation_id_; + const std::string member_id_; + const int64_t retention_time_; + const NullableArray topics_; +}; + +class OffsetCommitRequestV2Deserializer: + public CompositeDeserializerWith5Delegates< + OffsetCommitRequestV2, + StringDeserializer, + Int32Deserializer, + StringDeserializer, + Int64Deserializer, + ArrayDeserializer + >{}; + +class OffsetCommitRequestV2Parser : public RequestParser { +public: + OffsetCommitRequestV2Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; +}; + +/* Represents 'partitions' element in OffsetCommitRequestV3 */ +struct OffsetCommitRequestV3Partition { + const int32_t partition_; + const int64_t offset_; + const NullableString metadata_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(partition_, dst); + written += encoder.encode(offset_, dst); + written += encoder.encode(metadata_, dst); + return written; + } + + bool operator==(const OffsetCommitRequestV3Partition& rhs) const { + return + partition_ == rhs.partition_ && + offset_ == rhs.offset_ && + metadata_ == rhs.metadata_; + }; + +}; + +class OffsetCommitRequestV3PartitionDeserializer: + public CompositeDeserializerWith3Delegates< + OffsetCommitRequestV3Partition, + Int32Deserializer, + Int64Deserializer, + NullableStringDeserializer + >{}; + +/* Represents 'topics' element in OffsetCommitRequestV3 */ +struct OffsetCommitRequestV3Topic { + const std::string topic_; + const NullableArray partitions_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(topic_, dst); + written += encoder.encode(partitions_, dst); + return written; + } + + bool operator==(const OffsetCommitRequestV3Topic& rhs) const { + return + topic_ == rhs.topic_ && + partitions_ == rhs.partitions_; + }; + +}; + +class 
OffsetCommitRequestV3TopicDeserializer: + public CompositeDeserializerWith2Delegates< + OffsetCommitRequestV3Topic, + StringDeserializer, + ArrayDeserializer + >{}; + +class OffsetCommitRequestV3 : public Request { +public: + OffsetCommitRequestV3( + std::string group_id, + int32_t generation_id, + std::string member_id, + int64_t retention_time, + NullableArray topics + ): + Request{8, 3}, + group_id_{group_id}, + generation_id_{generation_id}, + member_id_{member_id}, + retention_time_{retention_time}, + topics_{topics} + {}; + + bool operator==(const OffsetCommitRequestV3& rhs) const { + return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && generation_id_ == rhs.generation_id_ && member_id_ == rhs.member_id_ && retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; + }; + +protected: + size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { + size_t written{0}; + written += encoder.encode(group_id_, dst); + written += encoder.encode(generation_id_, dst); + written += encoder.encode(member_id_, dst); + written += encoder.encode(retention_time_, dst); + written += encoder.encode(topics_, dst); + return written; + } + +private: + const std::string group_id_; + const int32_t generation_id_; + const std::string member_id_; + const int64_t retention_time_; + const NullableArray topics_; +}; + +class OffsetCommitRequestV3Deserializer: + public CompositeDeserializerWith5Delegates< + OffsetCommitRequestV3, + StringDeserializer, + Int32Deserializer, + StringDeserializer, + Int64Deserializer, + ArrayDeserializer + >{}; + +class OffsetCommitRequestV3Parser : public RequestParser { +public: + OffsetCommitRequestV3Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; +}; + +}}}} +// clang-format on diff --git a/source/extensions/filters/network/kafka/generated/serialization_composite.h b/source/extensions/filters/network/kafka/generated/serialization_composite.h new file mode 100644 index 0000000000000..2a7b2ae261516 --- /dev/null +++ b/source/extensions/filters/network/kafka/generated/serialization_composite.h @@ -0,0 +1,609 @@ +// DO NOT EDIT - THIS FILE WAS GENERATED +// clang-format off +#pragma once + +#include +#include +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/exception.h" +#include "envoy/common/pure.h" + +#include "common/common/byte_order.h" +#include "common/common/fmt.h" + +#include "extensions/filters/network/kafka/kafka_types.h" +#include "extensions/filters/network/kafka/serialization.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * This header contains only composite deserializers + * The basic design is composite deserializer creating delegates DeserializerType1..N + * Result of type ResponseType is constructed by getting results of each of delegates + */ + +/** + * Composite deserializer that uses 0 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} + * + * @param ResponseType type of deserialized data + */ +template < + typename ResponseType +> +class CompositeDeserializerWith0Delegates : public Deserializer { +public: + + CompositeDeserializerWith0Delegates(){}; + + size_t feed(const char*&, uint64_t&) { + return 0; + } + + bool ready() const { + return true; + } + + ResponseType get() const { + return { + }; + } + +protected: +}; + +/** + * Composite deserializer that uses 1 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1 +> +class CompositeDeserializerWith1Delegates : public Deserializer { +public: + + CompositeDeserializerWith1Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate1_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get() + }; + } + +protected: + DeserializerType1 delegate1_; +}; + +/** + * Composite deserializer that uses 2 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2 +> +class CompositeDeserializerWith2Delegates : public Deserializer { +public: + + CompositeDeserializerWith2Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate2_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; +}; + +/** + * Composite deserializer that uses 3 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
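The 0-delegate variant above is the generated successor of the hand-written NullDeserializer removed from serialization.h earlier in this series: it consumes no bytes, is always ready, and value-initializes its result. A minimal sketch of that contract, with a hypothetical result type:

struct Empty {};

CompositeDeserializerWith0Delegates<Empty> testee;
const char* data = nullptr;
uint64_t remaining = 0;
// feed() consumes nothing and ready() is true from the start.
size_t consumed = testee.feed(data, remaining); // consumed == 0
bool done = testee.ready();                     // done == true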
} + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3 +> +class CompositeDeserializerWith3Delegates : public Deserializer { +public: + + CompositeDeserializerWith3Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate3_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; +}; + +/** + * Composite deserializer that uses 4 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3, + typename DeserializerType4 +> +class CompositeDeserializerWith4Delegates : public Deserializer { +public: + + CompositeDeserializerWith4Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + consumed += delegate4_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate4_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get(), + delegate4_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; +}; + +/** + * Composite deserializer that uses 5 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) + * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3, + typename DeserializerType4, + typename DeserializerType5 +> +class CompositeDeserializerWith5Delegates : public Deserializer { +public: + + CompositeDeserializerWith5Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + consumed += delegate4_.feed(buffer, remaining); + consumed += delegate5_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate5_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get(), + delegate4_.get(), + delegate5_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; + DeserializerType5 delegate5_; +}; + +/** + * Composite deserializer that uses 6 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) + * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) + * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3, + typename DeserializerType4, + typename DeserializerType5, + typename DeserializerType6 +> +class CompositeDeserializerWith6Delegates : public Deserializer { +public: + + CompositeDeserializerWith6Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + consumed += delegate4_.feed(buffer, remaining); + consumed += delegate5_.feed(buffer, remaining); + consumed += delegate6_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate6_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get(), + delegate4_.get(), + delegate5_.get(), + delegate6_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; + DeserializerType5 delegate5_; + DeserializerType6 delegate6_; +}; + +/** + * Composite deserializer that uses 7 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) + * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) + * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) + * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3, + typename DeserializerType4, + typename DeserializerType5, + typename DeserializerType6, + typename DeserializerType7 +> +class CompositeDeserializerWith7Delegates : public Deserializer { +public: + + CompositeDeserializerWith7Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + consumed += delegate4_.feed(buffer, remaining); + consumed += delegate5_.feed(buffer, remaining); + consumed += delegate6_.feed(buffer, remaining); + consumed += delegate7_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate7_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get(), + delegate4_.get(), + delegate5_.get(), + delegate6_.get(), + delegate7_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; + DeserializerType5 delegate5_; + DeserializerType6 delegate6_; + DeserializerType7 delegate7_; +}; + +/** + * Composite deserializer that uses 8 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) + * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) + * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) + * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) + * @param DeserializerType8 deserializer 8 (result used as argument 8 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3, + typename DeserializerType4, + typename DeserializerType5, + typename DeserializerType6, + typename DeserializerType7, + typename DeserializerType8 +> +class CompositeDeserializerWith8Delegates : public Deserializer { +public: + + CompositeDeserializerWith8Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + consumed += delegate4_.feed(buffer, remaining); + consumed += delegate5_.feed(buffer, remaining); + consumed += delegate6_.feed(buffer, remaining); + consumed += delegate7_.feed(buffer, remaining); + consumed += delegate8_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate8_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get(), + delegate4_.get(), + delegate5_.get(), + delegate6_.get(), + delegate7_.get(), + delegate8_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; + DeserializerType5 delegate5_; + DeserializerType6 delegate6_; + DeserializerType7 delegate7_; + DeserializerType8 delegate8_; +}; + +/** + * Composite deserializer that uses 9 deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} + * + * @param ResponseType type of deserialized data + * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) + * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) + * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) + * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) + * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) + * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) + * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) + * @param DeserializerType8 deserializer 8 (result used as argument 8 of ResponseType's ctor) + * @param DeserializerType9 deserializer 9 (result used as argument 9 of ResponseType's ctor) + */ +template < + typename ResponseType, + typename DeserializerType1, + typename DeserializerType2, + typename DeserializerType3, + typename DeserializerType4, + typename DeserializerType5, + typename DeserializerType6, + typename DeserializerType7, + typename DeserializerType8, + typename DeserializerType9 +> +class CompositeDeserializerWith9Delegates : public Deserializer { +public: + + CompositeDeserializerWith9Delegates(){}; + + size_t feed(const char*& buffer, uint64_t& remaining) { + size_t consumed = 0; + consumed += delegate1_.feed(buffer, remaining); + consumed += delegate2_.feed(buffer, remaining); + consumed += delegate3_.feed(buffer, remaining); + consumed += delegate4_.feed(buffer, remaining); + consumed += delegate5_.feed(buffer, remaining); + consumed += delegate6_.feed(buffer, remaining); + consumed += delegate7_.feed(buffer, remaining); + consumed += delegate8_.feed(buffer, remaining); + consumed += delegate9_.feed(buffer, remaining); + return consumed; + } + + bool ready() const { + return delegate9_.ready(); + } + + ResponseType get() const { + return { + delegate1_.get(), + delegate2_.get(), + delegate3_.get(), + delegate4_.get(), + delegate5_.get(), + delegate6_.get(), + delegate7_.get(), + delegate8_.get(), + delegate9_.get() + }; + } + +protected: + DeserializerType1 delegate1_; + DeserializerType2 delegate2_; + DeserializerType3 delegate3_; + DeserializerType4 delegate4_; + DeserializerType5 delegate5_; + DeserializerType6 delegate6_; + DeserializerType7 delegate7_; + DeserializerType8 delegate8_; + DeserializerType9 delegate9_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy +// clang-format on diff --git a/source/extensions/filters/network/kafka/kafka_protocol.h b/source/extensions/filters/network/kafka/kafka_protocol.h deleted file mode 100644 index 3d1f07498b37c..0000000000000 --- a/source/extensions/filters/network/kafka/kafka_protocol.h +++ /dev/null @@ -1,19 +0,0 @@ -#pragma once - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -/** - * Kafka request type identifier (int16_t value present in header of every request) - * @see http://kafka.apache.org/protocol.html#protocol_api_keys - */ -enum RequestType : int16_t { - OffsetCommit = 8, -}; - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request.cc index 529822be8a069..b41d0ea714c9a 100644 --- 
a/source/extensions/filters/network/kafka/kafka_request.cc +++ b/source/extensions/filters/network/kafka/kafka_request.cc @@ -1,7 +1,6 @@ #include "extensions/filters/network/kafka/kafka_request.h" -#include "extensions/filters/network/kafka/kafka_protocol.h" -#include "extensions/filters/network/kafka/messages/offset_commit.h" +#include "extensions/filters/network/kafka/generated/requests.h" #include "extensions/filters/network/kafka/parser.h" namespace Envoy { @@ -9,70 +8,6 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -// helper function that generates a map from specs looking like { api_key, api_versions... } -ParserGenerators computeGeneratorMap(const ParserGenerators& original, - const std::vector specs) { - ParserGenerators result{original}; - for (auto& spec : specs) { - auto& generators = result[spec.api_key_]; - for (int16_t api_version : spec.api_versions_) { - generators[api_version] = spec.generator_; - } - } - - return result; -} - -RequestParserResolver::RequestParserResolver(const std::vector specs) - : generators_{computeGeneratorMap({}, specs)} {}; - -RequestParserResolver::RequestParserResolver(const RequestParserResolver& original, - const std::vector specs) - : generators_{computeGeneratorMap(original.generators_, specs)} {}; - -// helper macro binding request type & api versions to Deserializers -// the rendered function will create a new instance of (REQUEST)RequestV(Version)Parser -// e.g. OffsetCommitRequestV0Parser -#define PARSER_SPEC(REQUEST_NAME, PARSER_VERSION, ...) \ - ParserSpec { \ - RequestType::REQUEST_NAME, {__VA_ARGS__}, [](RequestContextSharedPtr arg) -> ParserSharedPtr { \ - return std::make_shared(arg); \ - } \ - } - -const RequestParserResolver RequestParserResolver::KAFKA_0_11{{ - PARSER_SPEC(OffsetCommit, V0, 0), PARSER_SPEC(OffsetCommit, V1, 1), - // XXX(adam.kotwasinski) missing request types here -}}; - -const RequestParserResolver RequestParserResolver::KAFKA_1_0{ - RequestParserResolver::KAFKA_0_11, - { - // XXX(adam.kotwasinski) missing request types & versions here - }}; - -ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version, - RequestContextSharedPtr context) const { - - // api_key - const auto api_versions_ptr = generators_.find(api_key); - if (generators_.end() == api_versions_ptr) { - return std::make_shared(context); - } - const std::unordered_map& api_versions = - api_versions_ptr->second; - - // api_version - const auto generator_ptr = api_versions.find(api_version); - if (api_versions.end() == generator_ptr) { - return std::make_shared(context); - } - - // found matching parser generator, create parser - const ParserGeneratorFunction generator = generator_ptr->second; - return generator(context); -} - ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining) { request_length_.feed(buffer, remaining); if (request_length_.ready()) { diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 45b3f7d64dd74..aff49af69b43a 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -6,10 +6,9 @@ #include "common/common/assert.h" -#include "extensions/filters/network/kafka/kafka_protocol.h" +#include "extensions/filters/network/kafka/generated/serialization_composite.h" #include "extensions/filters/network/kafka/parser.h" #include "extensions/filters/network/kafka/serialization.h" -#include 
"extensions/filters/network/kafka/serialization_composite.h" namespace Envoy { namespace Extensions { @@ -42,27 +41,6 @@ struct RequestContext { typedef std::shared_ptr RequestContextSharedPtr; -/** - * Function generating a parser with given context - */ -typedef std::function ParserGeneratorFunction; - -/** - * Structure responsible for mapping [api_key, api_version] -> ParserGeneratorFunction - */ -typedef std::unordered_map> - ParserGenerators; - -/** - * Trivial structure specifying which parser generator function should be used - * for which api_key & api_version - */ -struct ParserSpec { - const int16_t api_key_; - const std::vector api_versions_; - const ParserGeneratorFunction generator_; -}; - /** * Configuration object * Resolves the parser that will be responsible for consuming the request-specific data @@ -70,17 +48,6 @@ struct ParserSpec { */ class RequestParserResolver { public: - /** - * Creates a resolver that uses generator functions provided by given specifications - */ - RequestParserResolver(const std::vector specs); - - /** - * Creates a resolver that uses generator functions provided by original resolver and then - * expanded by specifications - */ - RequestParserResolver(const RequestParserResolver& original, const std::vector specs); - virtual ~RequestParserResolver() = default; /** @@ -94,17 +61,9 @@ class RequestParserResolver { RequestContextSharedPtr context) const; /** - * Request versions handled by Kafka up to 0.11 + * Request parser singleton */ - static const RequestParserResolver KAFKA_0_11; - - /** - * Request versions handled by Kafka up to 1.0 - */ - static const RequestParserResolver KAFKA_1_0; - -private: - ParserGenerators generators_; + static const RequestParserResolver INSTANCE; }; /** @@ -189,7 +148,8 @@ template class RequestParser : // after a successful parse, there should be nothing left - we have consumed all the bytes ASSERT(0 == context_->remaining_request_size_); RequestType request = deserializer.get(); - request.header() = context_->request_header_; + const RequestHeader& parsed_header = context_->request_header_; + request.setMetadata(parsed_header.correlation_id_, parsed_header.client_id_); MessageSharedPtr msg = std::make_shared(request); return ParseResponse::parsedMessage(msg); } else { @@ -202,19 +162,6 @@ template class RequestParser : DeserializerType deserializer; // underlying request-specific deserializer }; -/** - * Helper macro defining RequestParser that uses the underlying Deserializer - * Aware of versioning - * Names of Deserializers/Parsers are influenced by org.apache.kafka.common.protocol.Protocol names - * Renders class named (Request)(Version)Parser e.g. 
OffsetCommitRequestV0Parser - */ -#define DEFINE_REQUEST_PARSER(REQUEST_TYPE, VERSION) \ - class REQUEST_TYPE##VERSION##Parser \ - : public RequestParser { \ - public: \ - REQUEST_TYPE##VERSION##Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; \ - }; - /** * Abstract Kafka request * Contains data present in every request @@ -225,18 +172,12 @@ class Request : public Message { /** * Request header fields need to be initialized by user in case of newly created requests */ - Request(int16_t api_key) : request_header_{api_key, 0, 0, ""} {}; - - Request(const RequestHeader& request_header) : request_header_{request_header} {}; - - RequestHeader& header() { return request_header_; } - - int16_t& apiVersion() { return request_header_.api_version_; } - int16_t apiVersion() const { return request_header_.api_version_; } + Request(int16_t api_key, int16_t api_version) : request_header_{api_key, api_version, 0, ""} {}; - int32_t& correlationId() { return request_header_.correlation_id_; } - - NullableString& clientId() { return request_header_.client_id_; } + void setMetadata(const int32_t correlation_id, const NullableString& client_id) { + request_header_.correlation_id_ = correlation_id; + request_header_.client_id_ = client_id; + } /** * Encodes given request into a buffer, with any extra configuration carried by the context @@ -268,7 +209,10 @@ class Request : public Message { */ class UnknownRequest : public Request { public: - UnknownRequest(const RequestHeader& request_header) : Request{request_header} {}; + UnknownRequest(const RequestHeader& request_header) + : Request{request_header.api_key_, request_header.api_version_} { + setMetadata(request_header.correlation_id_, request_header.client_id_); + }; protected: // this isn't the prettiest, as we have thrown away the data diff --git a/source/extensions/filters/network/kafka/messages/offset_commit.h b/source/extensions/filters/network/kafka/messages/offset_commit.h deleted file mode 100644 index 3e81bee3b6df9..0000000000000 --- a/source/extensions/filters/network/kafka/messages/offset_commit.h +++ /dev/null @@ -1,174 +0,0 @@ -#pragma once - -#include "extensions/filters/network/kafka/kafka_request.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -/** - * Generic description : http://kafka.apache.org/protocol.html#The_Messages_OffsetCommit - */ - -/** - * Holds the partition data: partition, offset, timestamp, metadata - */ -struct OffsetCommitPartition { - const int32_t partition_; - const int64_t offset_; - const int64_t timestamp_; // only v1 - const NullableString metadata_; - - // v0 *and* v2 - OffsetCommitPartition(int32_t partition, int64_t offset, NullableString metadata) - : partition_{partition}, offset_{offset}, timestamp_{-1}, metadata_{metadata} {}; - - // v1 - OffsetCommitPartition(int32_t partition, int64_t offset, int64_t timestamp, - NullableString metadata) - : partition_{partition}, offset_{offset}, timestamp_{timestamp}, metadata_{metadata} {}; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(offset_, dst); - if (encoder.apiVersion() == 1) { - written += encoder.encode(timestamp_, dst); - } - written += encoder.encode(metadata_, dst); - return written; - } - - bool operator==(const OffsetCommitPartition& rhs) const { - return partition_ == rhs.partition_ && offset_ == rhs.offset_ && timestamp_ == rhs.timestamp_ && - metadata_ == rhs.metadata_; - 
}; -}; - -/** - * Holds the topic data: topic name and partitions in that topic - */ -struct OffsetCommitTopic { - const std::string topic_; - const NullableArray partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetCommitTopic& rhs) const { - return topic_ == rhs.topic_ && partitions_ == rhs.partitions_; - }; -}; - -/** - * Holds the request: group id, generation id, member id, retention time, topics - */ -class OffsetCommitRequest : public Request { -public: - // v0 - OffsetCommitRequest(std::string group_id, NullableArray topics) - : OffsetCommitRequest(group_id, -1, "", -1, topics){}; - - // v1 - OffsetCommitRequest(std::string group_id, int32_t group_generation_id, std::string member_id, - NullableArray topics) - : OffsetCommitRequest(group_id, group_generation_id, member_id, -1, topics){}; - - // v2 .. v3 - OffsetCommitRequest(std::string group_id, int32_t group_generation_id, std::string member_id, - int64_t retention_time, NullableArray topics) - : Request{RequestType::OffsetCommit}, group_id_{group_id}, - group_generation_id_{group_generation_id}, member_id_{member_id}, - retention_time_{retention_time}, topics_{topics} {}; - - bool operator==(const OffsetCommitRequest& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && - group_generation_id_ == rhs.group_generation_id_ && member_id_ == rhs.member_id_ && - retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - if (encoder.apiVersion() >= 1) { - written += encoder.encode(group_generation_id_, dst); - written += encoder.encode(member_id_, dst); - } - if (encoder.apiVersion() >= 2) { - written += encoder.encode(retention_time_, dst); - } - written += encoder.encode(topics_, dst); - return written; - } - -private: - const std::string group_id_; - const int32_t group_generation_id_; // since v1 - const std::string member_id_; // since v1 - const int64_t retention_time_; // since v2 - const NullableArray topics_; -}; - -// api version 0 - -// Deserializes bytes into OffsetCommitPartition (api version 0): partition, offset, metadata -class OffsetCommitPartitionV0Buffer - : public CompositeDeserializerWith3Delegates {}; -// Deserializes array of OffsetCommitPartition-s v0 -class OffsetCommitPartitionV0ArrayBuffer - : public ArrayDeserializer {}; -// Deserializes bytes into OffsetCommitTopic (api version 0): topic name, partitions (v0) -class OffsetCommitTopicV0Buffer - : public CompositeDeserializerWith2Delegates {}; -// Deserializes array of OffsetCommitTopic-s v0 -class OffsetCommitTopicV0ArrayBuffer - : public ArrayDeserializer {}; -// Deserializes bytes into OffsetCommitRequest (api version 0): group_id, topics (v0) -class OffsetCommitRequestV0Deserializer - : public CompositeDeserializerWith2Delegates {}; - -// api version 1 - -// Deserializes bytes into OffsetCommitPartition (api version 1): partition, offset, timestamp, -// metadata -class OffsetCommitPartitionV1Buffer - : public CompositeDeserializerWith4Delegates {}; -// Deserializes array of OffsetCommitPartition-s v1 -class OffsetCommitPartitionV1ArrayBuffer - : public ArrayDeserializer {}; -// Deserializes bytes into OffsetCommitTopic (api version 1): 
topic name, partitions (v1) -class OffsetCommitTopicV1Buffer - : public CompositeDeserializerWith2Delegates {}; -// Deserializes array of OffsetCommitTopic-s v1 -class OffsetCommitTopicV1ArrayBuffer - : public ArrayDeserializer {}; -// Deserializes bytes into OffsetCommitRequest (api version 1): group_id, generation_id, member_id, -// topics (v1) -class OffsetCommitRequestV1Deserializer - : public CompositeDeserializerWith4Delegates {}; - -/** - * Define Parsers that wrap the corresponding deserializers - */ - -DEFINE_REQUEST_PARSER(OffsetCommitRequest, V0); -DEFINE_REQUEST_PARSER(OffsetCommitRequest, V1); - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index b51053337a7cb..9ef3a54471758 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -3,8 +3,6 @@ #include "common/buffer/buffer_impl.h" #include "common/common/stack_array.h" -#include "extensions/filters/network/kafka/kafka_protocol.h" - namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -56,10 +54,7 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl } void RequestEncoder::encode(const Request& message) { - // XXX (adam.kotwasinski) theoretically this context could be generated inside Request::encode (as - // the requested knows the api_version), but the serialization design is still to be discussed - // (explicit classes vs vectors of pointers vs templates) - EncodingContext encoder{message.apiVersion()}; + EncodingContext encoder; Buffer::OwnedImpl data_buffer; // TODO (adam.kotwasinski) precompute the size instead of using temporary // also, when we have 'computeSize' method, then we can push encoding request's size into diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index f7931b8a5bb35..65c2f36e51590 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -46,7 +46,7 @@ class RequestDecoder : public MessageDecoder, public Logger::Loggable callbacks) : parser_resolver_{parserResolver}, callbacks_{callbacks}, current_parser_{new RequestStartParser(parser_resolver_)} {}; @@ -62,7 +62,7 @@ class RequestDecoder : public MessageDecoder, public Logger::Loggable callbacks_; ParserSharedPtr current_parser_; diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index bb99e87f47f53..05a8f1e4ae0da 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -517,15 +517,9 @@ class ArrayDeserializer : public Deserializer> { * Encodes provided argument in Kafka format * In case of primitive types, this is done explicitly as per spec * In case of composite types, this is done by calling 'encode' on provided argument - * - * This object also carries extra information that is used while traversing the request - * structure-tree during encryping (currently api_version, as different request versions serialize - * differently) */ class EncodingContext { public: - EncodingContext(int16_t api_version) : api_version_{api_version} {}; - /** * Encode given reference in a buffer * @return bytes written @@ -537,11 +531,6 @@ class EncodingContext 
{ * @return bytes written */ template size_t encode(const NullableArray& arg, Buffer::Instance& dst); - - int16_t apiVersion() const { return api_version_; } - -private: - const int16_t api_version_; }; /** diff --git a/source/extensions/filters/network/kafka/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h deleted file mode 100644 index 14f96ae8a1ad2..0000000000000 --- a/source/extensions/filters/network/kafka/serialization_composite.h +++ /dev/null @@ -1,135 +0,0 @@ -#pragma once - -#include -#include -#include -#include - -#include "envoy/buffer/buffer.h" -#include "envoy/common/exception.h" -#include "envoy/common/pure.h" - -#include "common/common/byte_order.h" -#include "common/common/fmt.h" - -#include "extensions/filters/network/kafka/kafka_types.h" -#include "extensions/filters/network/kafka/serialization.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -/** - * This header contains only composite deserializers - * The basic design is composite deserializer creating delegates DeserializerType1..Tn - * Result of type ResponseType is constructed by getting results of each of delegates - */ - -/** - * Composite deserializer that uses 2 deserializers - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 1st deserializer (result used as 1st argument of ResponseType's ctor) - * @param DeserializerType2 2nd deserializer (result used as 2nd argument of ResponseType's ctor) - */ -template -class CompositeDeserializerWith2Delegates : public Deserializer { -public: - CompositeDeserializerWith2Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return delegate2_.ready(); } - ResponseType get() const { return {delegate1_.get(), delegate2_.get()}; } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; -}; - -/** - * Composite deserializer that uses 3 deserializers - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 1st deserializer (result used as 1st argument of ResponseType's ctor) - * @param DeserializerType2 2nd deserializer (result used as 2nd argument of ResponseType's ctor) - * @param DeserializerType3 3rd deserializer (result used as 3rd argument of ResponseType's ctor) - */ -template -class CompositeDeserializerWith3Delegates : public Deserializer { -public: - CompositeDeserializerWith3Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return delegate3_.ready(); } - ResponseType get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; -}; - -/** - * Composite deserializer that uses 4 deserializers - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 1st deserializer (result used as 1st argument of ResponseType's ctor) - * @param DeserializerType2 2nd deserializer (result used as 2nd argument of ResponseType's ctor) - * @param DeserializerType3 3rd deserializer (result used as 3rd argument of ResponseType's ctor) - * @param DeserializerType4 4th deserializer (result used as 4th argument of ResponseType's ctor) - */ -template -class CompositeDeserializerWith4Delegates : public Deserializer { -public: - CompositeDeserializerWith4Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - return consumed; - } - bool ready() const { return delegate4_.ready(); } - ResponseType get() const { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; -}; - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index ab85bf20aaced..cc4ad6c12bcac 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -21,6 +21,16 @@ envoy_extension_cc_test( ], ) +envoy_extension_cc_test( + name = "serialization_composite_test", + srcs = ["generated/serialization_composite_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:serialization_lib", + "//test/mocks/server:server_mocks", + ], +) + envoy_extension_cc_test( name = "kafka_request_test", srcs = ["kafka_request_test.cc"], @@ -40,3 +50,13 @@ envoy_extension_cc_test( "//test/mocks/server:server_mocks", ], ) + +envoy_extension_cc_test( + name = "requests_test", + 
srcs = ["generated/requests_test.cc"],
+    extension_name = "envoy.filters.network.kafka",
+    deps = [
+        "//source/extensions/filters/network/kafka:kafka_request_codec_lib",
+        "//test/mocks/server:server_mocks",
+    ],
+)
diff --git a/test/extensions/filters/network/kafka/generated/requests_test.cc b/test/extensions/filters/network/kafka/generated/requests_test.cc
new file mode 100644
index 0000000000000..43a7b08599b86
--- /dev/null
+++ b/test/extensions/filters/network/kafka/generated/requests_test.cc
@@ -0,0 +1,95 @@
+// DO NOT EDIT - THIS FILE WAS GENERATED
+// clang-format off
+#include "extensions/filters/network/kafka/generated/requests.h"
+#include "extensions/filters/network/kafka/request_codec.h"
+
+#include "test/mocks/server/mocks.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+class RequestDecoderTest : public testing::Test {
+public:
+  Buffer::OwnedImpl buffer_;
+
+  template <typename T> std::shared_ptr<T> serializeAndDeserialize(T request);
+};
+
+class MockMessageListener : public RequestCallback {
+public:
+  MOCK_METHOD1(onMessage, void(MessageSharedPtr));
+};
+
+template <typename T>
+std::shared_ptr<T> RequestDecoderTest::serializeAndDeserialize(T request) {
+  RequestEncoder serializer{buffer_};
+  serializer.encode(request);
+
+  std::shared_ptr<MockMessageListener> mock_listener = std::make_shared<MockMessageListener>();
+  RequestDecoder testee{RequestParserResolver::INSTANCE, {mock_listener}};
+
+  MessageSharedPtr receivedMessage;
+  EXPECT_CALL(*mock_listener, onMessage(testing::_)).WillOnce(testing::SaveArg<0>(&receivedMessage));
+
+  testee.onData(buffer_);
+
+  return std::dynamic_pointer_cast<T>(receivedMessage);
+};
+
+TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV0) {
+  // given
+  OffsetCommitRequestV0 request = {"string", {{ {"string", {{ {32, 64, {"nullable"}, } }}, } }}, };
+
+  // when
+  auto received = serializeAndDeserialize(request);
+
+  // then
+  ASSERT_NE(received, nullptr);
+  ASSERT_EQ(*received, request);
+}
+
+TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) {
+  // given
+  OffsetCommitRequestV1 request = {"string", 32, "string", {{ {"string", {{ {32, 64, 64, {"nullable"}, } }}, } }}, };
+
+  // when
+  auto received = serializeAndDeserialize(request);
+
+  // then
+  ASSERT_NE(received, nullptr);
+  ASSERT_EQ(*received, request);
+}
+
+TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV2) {
+  // given
+  OffsetCommitRequestV2 request = {"string", 32, "string", 64, {{ {"string", {{ {32, 64, {"nullable"}, } }}, } }}, };
+
+  // when
+  auto received = serializeAndDeserialize(request);
+
+  // then
+  ASSERT_NE(received, nullptr);
+  ASSERT_EQ(*received, request);
+}
+
+TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV3) {
+  // given
+  OffsetCommitRequestV3 request = {"string", 32, "string", 64, {{ {"string", {{ {32, 64, {"nullable"}, } }}, } }}, };
+
+  // when
+  auto received = serializeAndDeserialize(request);
+
+  // then
+  ASSERT_NE(received, nullptr);
+  ASSERT_EQ(*received, request);
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
+// clang-format on
diff --git a/test/extensions/filters/network/kafka/generated/serialization_composite_test.cc b/test/extensions/filters/network/kafka/generated/serialization_composite_test.cc
new file mode 100644
index 0000000000000..d310bf4530e15
--- /dev/null
+++ b/test/extensions/filters/network/kafka/generated/serialization_composite_test.cc
@@ -0,0 +1,484 @@
+// DO NOT EDIT - THIS FILE WAS GENERATED
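+//
+// For orientation when reading the generated cases below: a composite deserializer feeds the
+// same buffer to each of its delegates in declaration order (delegates that are already ready
+// consume nothing), and becomes ready once its last delegate is ready. A minimal hand-written
+// sketch of the two-delegate case (the names here are illustrative only) would be:
+//
+//   struct Pair { const int32_t first_; const std::string second_; };
+//   class PairDeserializer
+//       : public CompositeDeserializerWith2Delegates<Pair, Int32Deserializer,
+//                                                    StringDeserializer> {};
+//   // PairDeserializer::feed() forwards to the Int32Deserializer, then the StringDeserializer;
+//   // get() builds the Pair from { delegate1_.get(), delegate2_.get() }.
+//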
+// clang-format off
+#include "common/common/stack_array.h"
+
+#include "extensions/filters/network/kafka/generated/serialization_composite.h"
+#include "extensions/filters/network/kafka/serialization.h"
+
+#include "test/mocks/server/mocks.h"
+
+#include "gmock/gmock.h"
+#include "gtest/gtest.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+/**
+ * Tests in this class check that serialization operations on composite deserializers
+ * behave correctly.
+ */
+
+// helper function
+const char* getRawData(const Buffer::OwnedImpl& buffer) {
+  uint64_t num_slices = buffer.getRawSlices(nullptr, 0);
+  STACK_ARRAY(slices, Buffer::RawSlice, num_slices);
+  buffer.getRawSlices(slices.begin(), num_slices);
+  return reinterpret_cast<const char*>((slices[0]).mem_);
+}
+
+// exactly what it says on the tin:
+// 1. serialize expected using Encoder
+// 2. deserialize byte array using testee deserializer
+// 3. verify result = expected
+// 4. verify that data pointer moved correct amount
+// 5. feed testee more data
+// 6. verify that nothing more was consumed
+template <typename BT, typename AT>
+void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) {
+  // given
+  BT testee{};
+
+  Buffer::OwnedImpl buffer;
+  EncodingContext encoder;
+  const size_t written = encoder.encode(expected, buffer);
+
+  // tell the parser that there is more data; it should never consume more than `written`
+  uint64_t remaining = 10 * written;
+  const uint64_t orig_remaining = remaining;
+  const char* data = getRawData(buffer);
+  const char* orig_data = data;
+
+  // when
+  const size_t consumed = testee.feed(data, remaining);
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), expected);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, orig_remaining - consumed);
+
+  // when - 2
+  const size_t consumed2 = testee.feed(data, remaining);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, orig_remaining - consumed);
+}
+
+// does the same thing as the test above,
+// but instead of providing the whole data at once, it provides it in N one-byte chunks
+// this verifies that the deserializer keeps its state properly (no overwrites etc.)
+template <typename BT, typename AT>
+void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) {
+  // given
+  BT testee{};
+
+  Buffer::OwnedImpl buffer;
+  EncodingContext encoder;
+  const size_t written = encoder.encode(expected, buffer);
+
+  const char* data = getRawData(buffer);
+  const char* orig_data = data;
+
+  // when
+  size_t consumed = 0;
+  for (size_t i = 0; i < written; ++i) {
+    uint64_t data_size = 1;
+    consumed += testee.feed(data, data_size);
+    ASSERT_EQ(data_size, 0);
+  }
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), expected);
+  ASSERT_EQ(data, orig_data + consumed);
+
+  // when - 2
+  uint64_t remaining = 1024;
+  const size_t consumed2 = testee.feed(data, remaining);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  ASSERT_EQ(data, orig_data + consumed);
+  ASSERT_EQ(remaining, 1024);
+}
+
+// wrapper to run both tests
+template <typename BT, typename AT> void serializeThenDeserializeAndCheckEquality(AT expected) {
+  serializeThenDeserializeAndCheckEqualityInOneGo<BT, AT>(expected);
+  serializeThenDeserializeAndCheckEqualityWithChunks<BT, AT>(expected);
+}
+
+// tests for composite deserializers
+
+struct CompositeResultWith0Fields {
+
+  size_t encode(Buffer::Instance&, EncodingContext&) const {
+    return 0;
+  }
+
+  bool operator==(const CompositeResultWith0Fields&) const {
+    return true;
+  }
+};
+
+typedef CompositeDeserializerWith0Delegates<CompositeResultWith0Fields>
+    TestCompositeDeserializer0;
+
+// composite with 0 delegates is a special case: it is always ready
+TEST(CompositeDeserializerWith0Delegates, EmptyBufferShouldBeReady) {
+  // given
+  const TestCompositeDeserializer0 testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), true);
+}
+
+TEST(CompositeDeserializerWith0Delegates, ShouldDeserialize) {
+  const CompositeResultWith0Fields expected{};
+  serializeThenDeserializeAndCheckEquality<TestCompositeDeserializer0>(expected);
+}
+
+struct CompositeResultWith1Fields {
+  const std::string field1_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    return written;
+  }
+
+  bool operator==(const CompositeResultWith1Fields& rhs) const {
+    return field1_ == rhs.field1_;
+  }
+};
+
+typedef CompositeDeserializerWith1Delegates<CompositeResultWith1Fields, StringDeserializer>
+    TestCompositeDeserializer1;
+
+TEST(CompositeDeserializerWith1Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  const TestCompositeDeserializer1 testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+TEST(CompositeDeserializerWith1Delegates, ShouldDeserialize) {
+  const CompositeResultWith1Fields expected{"s1"};
+  serializeThenDeserializeAndCheckEquality<TestCompositeDeserializer1>(expected);
+}
+
+struct CompositeResultWith2Fields {
+  const std::string field1_;
+  const std::string field2_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    written += encoder.encode(field2_, dst);
+    return written;
+  }
+
+  bool operator==(const CompositeResultWith2Fields& rhs) const {
+    return field1_ == rhs.field1_ && field2_ == rhs.field2_;
+  }
+};
+
+typedef CompositeDeserializerWith2Delegates<CompositeResultWith2Fields, StringDeserializer,
+                                            StringDeserializer>
+    TestCompositeDeserializer2;
+
+TEST(CompositeDeserializerWith2Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  const TestCompositeDeserializer2 testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+TEST(CompositeDeserializerWith2Delegates, ShouldDeserialize) {
+  const CompositeResultWith2Fields expected{"s1", "s2"};
+  serializeThenDeserializeAndCheckEquality<TestCompositeDeserializer2>(expected);
+}
+
+struct CompositeResultWith3Fields {
+  const std::string field1_;
+  const std::string field2_;
+  const std::string field3_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    written += encoder.encode(field2_, dst);
+    written += encoder.encode(field3_, dst);
+    return written;
+  }
+
+  bool operator==(const CompositeResultWith3Fields& rhs) const {
+    return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_;
+  }
+};
+
+typedef CompositeDeserializerWith3Delegates<CompositeResultWith3Fields, StringDeserializer,
+                                            StringDeserializer, StringDeserializer>
+    TestCompositeDeserializer3;
+
+TEST(CompositeDeserializerWith3Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  const TestCompositeDeserializer3 testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+TEST(CompositeDeserializerWith3Delegates, ShouldDeserialize) {
+  const CompositeResultWith3Fields expected{"s1", "s2", "s3"};
+  serializeThenDeserializeAndCheckEquality<TestCompositeDeserializer3>(expected);
+}
+
+struct CompositeResultWith4Fields {
+  const std::string field1_;
+  const std::string field2_;
+  const std::string field3_;
+  const std::string field4_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    written += encoder.encode(field2_, dst);
+    written += encoder.encode(field3_, dst);
+    written += encoder.encode(field4_, dst);
+    return written;
+  }
+
+  bool operator==(const CompositeResultWith4Fields& rhs) const {
+    return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ &&
+           field4_ == rhs.field4_;
+  }
+};
+
+typedef CompositeDeserializerWith4Delegates<CompositeResultWith4Fields, StringDeserializer,
+                                            StringDeserializer, StringDeserializer,
+                                            StringDeserializer>
+    TestCompositeDeserializer4;
+
+TEST(CompositeDeserializerWith4Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  const TestCompositeDeserializer4 testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+TEST(CompositeDeserializerWith4Delegates, ShouldDeserialize) {
+  const CompositeResultWith4Fields expected{"s1", "s2", "s3", "s4"};
+  serializeThenDeserializeAndCheckEquality<TestCompositeDeserializer4>(expected);
+}
+
+struct CompositeResultWith5Fields {
+  const std::string field1_;
+  const std::string field2_;
+  const std::string field3_;
+  const std::string field4_;
+  const std::string field5_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
+    written += encoder.encode(field2_, dst);
+    written += encoder.encode(field3_, dst);
+    written += encoder.encode(field4_, dst);
+    written += encoder.encode(field5_, dst);
+    return written;
+  }
+
+  bool operator==(const CompositeResultWith5Fields& rhs) const {
+    return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ &&
+           field4_ == rhs.field4_ && field5_ == rhs.field5_;
+  }
+};
+
+typedef CompositeDeserializerWith5Delegates<CompositeResultWith5Fields, StringDeserializer,
+                                            StringDeserializer, StringDeserializer,
+                                            StringDeserializer, StringDeserializer>
+    TestCompositeDeserializer5;
+
+TEST(CompositeDeserializerWith5Delegates, EmptyBufferShouldNotBeReady) {
+  // given
+  const TestCompositeDeserializer5 testee{};
+  // when, then
+  ASSERT_EQ(testee.ready(), false);
+}
+
+TEST(CompositeDeserializerWith5Delegates, ShouldDeserialize) {
+  const CompositeResultWith5Fields expected{"s1", "s2", "s3", "s4", "s5"};
+  serializeThenDeserializeAndCheckEquality<TestCompositeDeserializer5>(expected);
+}
+
+struct CompositeResultWith6Fields {
+  const std::string field1_;
+  const std::string field2_;
+  const std::string field3_;
+  const std::string field4_;
+  const std::string field5_;
+  const std::string field6_;
+
+  size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
+    size_t written{0};
+    written += encoder.encode(field1_, dst);
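+    // (as in every generated encode() in this file: fields are written in declaration order,
+    // which is also the delegate order of the matching composite deserializer)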
+ written += encoder.encode(field2_, dst); + written += encoder.encode(field3_, dst); + written += encoder.encode(field4_, dst); + written += encoder.encode(field5_, dst); + written += encoder.encode(field6_, dst); + return written; + } + + bool operator==(const CompositeResultWith6Fields& rhs) const { + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_; + } +}; + +typedef CompositeDeserializerWith6Delegates + TestCompositeDeserializer6; + +TEST(CompositeDeserializerWith6Delegates, EmptyBufferShouldNotBeReady) { + // given + const TestCompositeDeserializer6 testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +TEST(CompositeDeserializerWith6Delegates, ShouldDeserialize) { + const CompositeResultWith6Fields expected{"s1", "s2", "s3", "s4", "s5", "s6"}; + serializeThenDeserializeAndCheckEquality(expected); +} + +struct CompositeResultWith7Fields { + const std::string field1_; + const std::string field2_; + const std::string field3_; + const std::string field4_; + const std::string field5_; + const std::string field6_; + const std::string field7_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(field1_, dst); + written += encoder.encode(field2_, dst); + written += encoder.encode(field3_, dst); + written += encoder.encode(field4_, dst); + written += encoder.encode(field5_, dst); + written += encoder.encode(field6_, dst); + written += encoder.encode(field7_, dst); + return written; + } + + bool operator==(const CompositeResultWith7Fields& rhs) const { + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && field7_ == rhs.field7_; + } +}; + +typedef CompositeDeserializerWith7Delegates + TestCompositeDeserializer7; + +TEST(CompositeDeserializerWith7Delegates, EmptyBufferShouldNotBeReady) { + // given + const TestCompositeDeserializer7 testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +TEST(CompositeDeserializerWith7Delegates, ShouldDeserialize) { + const CompositeResultWith7Fields expected{"s1", "s2", "s3", "s4", "s5", "s6", "s7"}; + serializeThenDeserializeAndCheckEquality(expected); +} + +struct CompositeResultWith8Fields { + const std::string field1_; + const std::string field2_; + const std::string field3_; + const std::string field4_; + const std::string field5_; + const std::string field6_; + const std::string field7_; + const std::string field8_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(field1_, dst); + written += encoder.encode(field2_, dst); + written += encoder.encode(field3_, dst); + written += encoder.encode(field4_, dst); + written += encoder.encode(field5_, dst); + written += encoder.encode(field6_, dst); + written += encoder.encode(field7_, dst); + written += encoder.encode(field8_, dst); + return written; + } + + bool operator==(const CompositeResultWith8Fields& rhs) const { + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && field7_ == rhs.field7_ && field8_ == rhs.field8_; + } +}; + +typedef CompositeDeserializerWith8Delegates + TestCompositeDeserializer8; + +TEST(CompositeDeserializerWith8Delegates, EmptyBufferShouldNotBeReady) { + // given + const 
TestCompositeDeserializer8 testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +TEST(CompositeDeserializerWith8Delegates, ShouldDeserialize) { + const CompositeResultWith8Fields expected{"s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"}; + serializeThenDeserializeAndCheckEquality(expected); +} + +struct CompositeResultWith9Fields { + const std::string field1_; + const std::string field2_; + const std::string field3_; + const std::string field4_; + const std::string field5_; + const std::string field6_; + const std::string field7_; + const std::string field8_; + const std::string field9_; + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(field1_, dst); + written += encoder.encode(field2_, dst); + written += encoder.encode(field3_, dst); + written += encoder.encode(field4_, dst); + written += encoder.encode(field5_, dst); + written += encoder.encode(field6_, dst); + written += encoder.encode(field7_, dst); + written += encoder.encode(field8_, dst); + written += encoder.encode(field9_, dst); + return written; + } + + bool operator==(const CompositeResultWith9Fields& rhs) const { + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && field7_ == rhs.field7_ && field8_ == rhs.field8_ && field9_ == rhs.field9_; + } +}; + +typedef CompositeDeserializerWith9Delegates + TestCompositeDeserializer9; + +TEST(CompositeDeserializerWith9Delegates, EmptyBufferShouldNotBeReady) { + // given + const TestCompositeDeserializer9 testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +TEST(CompositeDeserializerWith9Delegates, ShouldDeserialize) { + const CompositeResultWith9Fields expected{"s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "s9"}; + serializeThenDeserializeAndCheckEquality(expected); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy +// clang-format on diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc index 268c6eda29678..a78fd835b30d2 100644 --- a/test/extensions/filters/network/kafka/kafka_request_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_test.cc @@ -1,7 +1,7 @@ #include "common/common/stack_array.h" +#include "extensions/filters/network/kafka/generated/requests.h" #include "extensions/filters/network/kafka/kafka_request.h" -#include "extensions/filters/network/kafka/messages/offset_commit.h" #include "test/mocks/server/mocks.h" @@ -16,52 +16,6 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -TEST(RequestParserResolver, ShouldReturnSentinelIfRequestTypeIsNotRegistered) { - // given - RequestParserResolver testee{{}}; - RequestContextSharedPtr context{new RequestContext{}}; - - // when - ParserSharedPtr result = testee.createParser(0, 1, context); // api_key = 0 was not registered - - // then - ASSERT_NE(result, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); -} - -TEST(RequestParserResolver, ShouldReturnSentinelIfRequestVersionIsNotRegistered) { - // given - ParserGeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { - return std::make_shared(arg); - }; - RequestParserResolver testee{{{0, {0, 1}, generator}}}; - RequestContextSharedPtr context{new RequestContext{}}; - - // when - ParserSharedPtr result = - testee.createParser(0, 2, context); // api_version = 2 
was not registered (0 & 1 were) - - // then - ASSERT_NE(result, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); -} - -TEST(RequestParserResolver, ShouldInvokeGeneratorFunctionOnMatch) { - // given - ParserGeneratorFunction generator = [](RequestContextSharedPtr arg) -> ParserSharedPtr { - return std::make_shared(arg); - }; - RequestParserResolver testee{{{0, {0, 1, 2, 3}, generator}}}; - RequestContextSharedPtr context{new RequestContext{}}; - - // when - ParserSharedPtr result = testee.createParser(0, 3, context); - - // then - ASSERT_NE(result, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result), nullptr); -} - class BufferBasedTest : public testing::Test { public: Buffer::OwnedImpl& buffer() { return buffer_; } @@ -75,12 +29,19 @@ class BufferBasedTest : public testing::Test { protected: Buffer::OwnedImpl buffer_; - EncodingContext encoder_{-1}; + EncodingContext encoder_; +}; + +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver(){}; + MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); }; TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { // given - RequestStartParser testee{RequestParserResolver{{}}}; + MockRequestParserResolver resolver{}; + RequestStartParser testee{resolver}; int32_t request_len = 1234; encoder_.encode(request_len, buffer()); @@ -98,12 +59,6 @@ TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); } -class MockRequestParserResolver : public RequestParserResolver { -public: - MockRequestParserResolver() : RequestParserResolver{{}} {}; - MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); -}; - TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) { // given const MockRequestParserResolver parser_resolver; diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc index 3bd45bbb52375..0c408c9d32a0a 100644 --- a/test/extensions/filters/network/kafka/request_codec_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_test.cc @@ -1,4 +1,4 @@ -#include "extensions/filters/network/kafka/messages/offset_commit.h" +#include "extensions/filters/network/kafka/generated/requests.h" #include "extensions/filters/network/kafka/request_codec.h" #include "test/mocks/server/mocks.h" @@ -25,12 +25,18 @@ class MockMessageListener : public RequestCallback { MOCK_METHOD1(onMessage, void(MessageSharedPtr)); }; +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver() : RequestParserResolver({}){}; + MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); +}; + template std::shared_ptr RequestDecoderTest::serializeAndDeserialize(T request) { RequestEncoder serializer{buffer_}; serializer.encode(request); std::shared_ptr mock_listener = std::make_shared(); - RequestDecoder testee{RequestParserResolver::KAFKA_0_11, {mock_listener}}; + RequestDecoder testee{RequestParserResolver::INSTANCE, {mock_listener}}; MessageSharedPtr receivedMessage; EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&receivedMessage)); @@ -40,57 +46,25 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria return std::dynamic_pointer_cast(receivedMessage); }; -TEST_F(RequestDecoderTest, 
shouldParseOffsetCommitRequestV0) { - // given - NullableArray topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; - OffsetCommitRequest request{"group_id", topics}; - request.apiVersion() = 0; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) { - // given - // partitions have timestamp in v1 only - NullableArray topics{ - {{"topic1", {{{0, 10, 100, "m1"}, {2, 20, 101, "m2"}}}}, {"topic2", {{{3, 30, 102, "m3"}}}}}}; - OffsetCommitRequest request{"group_id", - 40, // group_generation_id - "member_id", // member_id - topics}; - request.apiVersion() = 1; - request.correlationId() = 10; - request.clientId() = "client-id"; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); +ParserSharedPtr createSentinelParser(testing::Unused, testing::Unused, + RequestContextSharedPtr context) { + return std::make_shared(context); } TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given RequestEncoder serializer{buffer_}; - NullableArray topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; - OffsetCommitRequest request{"group_id", topics}; - request.apiVersion() = 1; - request.correlationId() = 42; - request.clientId() = "client-id"; + NullableArray topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; + OffsetCommitRequestV0 request{"group_id", topics}; + request.setMetadata(42, "client-id"); serializer.encode(request); + MockRequestParserResolver mock_parser_resolver{}; + EXPECT_CALL(mock_parser_resolver, createParser(_, _, _)) + .WillOnce(testing::Invoke(createSentinelParser)); std::shared_ptr mock_listener = std::make_shared(); - RequestParserResolver parser_resolver{{}}; // we do not accept any kind of message here - RequestDecoder testee{parser_resolver, {mock_listener}}; + RequestDecoder testee{mock_parser_resolver, {mock_listener}}; MessageSharedPtr rev; EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&rev)); @@ -99,6 +73,7 @@ TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { testee.onData(buffer_); // then + ASSERT_NE(rev, nullptr); auto received = std::dynamic_pointer_cast(rev); ASSERT_NE(received, nullptr); } diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 65655ec7a460c..a684e7bc649f2 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -1,15 +1,13 @@ #include "common/common/stack_array.h" +#include "extensions/filters/network/kafka/generated/serialization_composite.h" #include "extensions/filters/network/kafka/serialization.h" -#include "extensions/filters/network/kafka/serialization_composite.h" #include "test/mocks/server/mocks.h" #include "gmock/gmock.h" #include "gtest/gtest.h" -using testing::_; - namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -39,38 +37,6 @@ TEST_EmptyDeserializerShouldNotBeReady(NullableStringDeserializer); TEST_EmptyDeserializerShouldNotBeReady(BytesDeserializer); TEST_EmptyDeserializerShouldNotBeReady(NullableBytesDeserializer); -TEST(CompositeDeserializerWith2Delegates, EmptyBufferShouldNotBeReady) { - // given - struct CompositeResult { - CompositeResult(int8_t, int16_t){}; - }; - const 
CompositeDeserializerWith2Delegates - testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} -TEST(CompositeDeserializerWith3Delegates, EmptyBufferShouldNotBeReady) { - // given - struct CompositeResult { - CompositeResult(int8_t, int16_t, int32_t){}; - }; - const CompositeDeserializerWith3Delegates - testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} -TEST(CompositeDeserializerWith4Delegates, EmptyBufferShouldNotBeReady) { - // given - struct CompositeResult { - CompositeResult(int8_t, int16_t, int32_t, std::string){}; - }; - const CompositeDeserializerWith4Delegates - testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) { // given const ArrayDeserializer testee{}; @@ -78,7 +44,7 @@ TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) { ASSERT_EQ(testee.ready(), false); } -EncodingContext encoder{-1}; // context is not used when serializing primitive types +EncodingContext encoder; // helper function const char* getRawData(const Buffer::OwnedImpl& buffer) { @@ -327,92 +293,6 @@ TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data, remaining), EnvoyException); } -// tests for composite deserializers - -struct CompositeResultWith2Fields { - std::string field1_; - NullableArray field2_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - return written; - } - - bool operator==(const CompositeResultWith2Fields& rhs) const { - return (field1_ == rhs.field1_) && (field2_ == rhs.field2_); - } -}; - -struct CompositeResultWith3Fields { - std::string field1_; - NullableArray field2_; - int16_t field3_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - return written; - } - - bool operator==(const CompositeResultWith3Fields& rhs) const { - return (field1_ == rhs.field1_) && (field2_ == rhs.field2_) && (field3_ == rhs.field3_); - } -}; - -struct CompositeResultWith4Fields { - std::string field1_; - NullableArray field2_; - int16_t field3_; - std::string field4_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - return written; - } - - bool operator==(const CompositeResultWith4Fields& rhs) const { - return (field1_ == rhs.field1_) && (field2_ == rhs.field2_) && (field3_ == rhs.field3_) && - (field4_ == rhs.field4_); - } -}; - -typedef CompositeDeserializerWith2Delegates> - TestCompositeDeserializer2; - -typedef CompositeDeserializerWith3Delegates, - Int16Deserializer> - TestCompositeDeserializer3; - -typedef CompositeDeserializerWith4Delegates, - Int16Deserializer, StringDeserializer> - TestCompositeDeserializer4; - -TEST(CompositeDeserializerWith2Delegates, ShouldDeserialize) { - const CompositeResultWith2Fields expected{"zzzzz", {{10, 20, 30, 40, 50}}}; - serializeThenDeserializeAndCheckEquality(expected); -} - -TEST(CompositeDeserializerWith3Delegates, ShouldDeserialize) { - const CompositeResultWith3Fields expected{"zzzzz", {{10, 20, 30, 40, 50}}, 1234}; - serializeThenDeserializeAndCheckEquality(expected); -} - -TEST(CompositeDeserializerWith4Delegates, 
ShouldDeserialize) { - const CompositeResultWith4Fields expected{"zzzzz", {{10, 20, 30, 40, 50}}, 1234, "aaa"}; - serializeThenDeserializeAndCheckEquality(expected); -} - } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions From db7763e7e44716a190f7657cef3450a4e87bda29 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 27 Nov 2018 15:44:28 +0000 Subject: [PATCH 11/29] Add python tool for generating Kafka request classes; keep all versions of request in a single structure Signed-off-by: Adam Kotwasinski --- api/bazel/repositories.bzl | 11 + source/extensions/filters/network/kafka/BUILD | 36 +- .../extensions/filters/network/kafka/codec.h | 11 +- .../network/kafka/generated/requests.h | 448 ------------------ .../filters/network/kafka/kafka_request.h | 194 +------- ...fka_request.cc => kafka_request_parser.cc} | 21 +- .../network/kafka/kafka_request_parser.h | 204 ++++++++ .../filters/network/kafka/message.h | 7 + .../complex_type_template.j2 | 37 ++ .../kafka_generator.py | 402 ++++++++++++++++ .../kafka_request_resolver_cc.j2} | 20 +- .../protocol_code_generator/request_parser.j2 | 6 + .../protocol_code_generator/requests_h.j2 | 15 + .../requests_test_cc.j2 | 48 +- .../filters/network/kafka/request_codec.cc | 10 +- .../filters/network/kafka/request_codec.h | 12 +- .../filters/network/kafka/serialization.h | 145 +++++- .../{generated => }/serialization_composite.h | 246 ++-------- test/extensions/filters/network/kafka/BUILD | 21 +- .../kafka/kafka_request_parser_test.cc | 249 ++++++++++ .../network/kafka/kafka_request_test.cc | 131 ----- .../network/kafka/request_codec_test.cc | 27 +- .../serialization_composite_test.cc | 79 +-- .../network/kafka/serialization_test.cc | 35 +- 24 files changed, 1330 insertions(+), 1085 deletions(-) delete mode 100644 source/extensions/filters/network/kafka/generated/requests.h rename source/extensions/filters/network/kafka/{kafka_request.cc => kafka_request_parser.cc} (70%) create mode 100644 source/extensions/filters/network/kafka/kafka_request_parser.h create mode 100644 source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 create mode 100755 source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py rename source/extensions/filters/network/kafka/{generated/kafka_request_resolver.cc => protocol_code_generator/kafka_request_resolver_cc.j2} (51%) create mode 100644 source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 create mode 100644 source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 rename test/extensions/filters/network/kafka/generated/requests_test.cc => source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 (52%) rename source/extensions/filters/network/kafka/{generated => }/serialization_composite.h (80%) create mode 100644 test/extensions/filters/network/kafka/kafka_request_parser_test.cc delete mode 100644 test/extensions/filters/network/kafka/kafka_request_test.cc rename test/extensions/filters/network/kafka/{generated => }/serialization_composite_test.cc (83%) diff --git a/api/bazel/repositories.bzl b/api/bazel/repositories.bzl index 8afc03dcbb725..c5805886a80f4 100644 --- a/api/bazel/repositories.bzl +++ b/api/bazel/repositories.bzl @@ -16,6 +16,7 @@ PROMETHEUS_SHA = "783bdaf8ee0464b35ec0c8704871e1e72afa0005c3f3587f65d9d6694bf391 OPENCENSUS_GIT_SHA = "ab82e5fdec8267dc2a726544b10af97675970847" # May 23, 2018 OPENCENSUS_SHA = 
"1950f844d9f338ba731897a9bb526f9074c0487b3f274ce2ec3b4feaf0bef7e2" +load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") def api_dependencies(): @@ -31,6 +32,16 @@ def api_dependencies(): strip_prefix = "protoc-gen-validate-" + PGV_RELEASE, url = "https://github.com/lyft/protoc-gen-validate/archive/v" + PGV_RELEASE + ".tar.gz", ) + http_file( + name = "kafka_produce_request_spec", + sha256 = "e035f70a136ef5a5ef2ff17b52dc10f2eae4ac596639689f5584054909d5816f", + urls = ["https://raw.githubusercontent.com/apache/kafka/2.2/clients/src/main/resources/common/message/ProduceRequest.json"], + ) + http_file( + name = "kafka_fetch_request_spec", + sha256 = "9209b68fe0295818071c2f644363cf71e6443eb61f8e9d2636412876c5e2bae8", + urls = ["https://raw.githubusercontent.com/apache/kafka/2.2/clients/src/main/resources/common/message/FetchRequest.json"], + ) http_archive( name = "googleapis", strip_prefix = "googleapis-" + GOOGLEAPIS_GIT_SHA, diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index e4c73e532b025..7a6387ff3a0b8 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -20,6 +20,7 @@ envoy_cc_library( ], deps = [ ":kafka_request_lib", + ":message_lib", "//source/common/buffer:buffer_lib", ], ) @@ -27,12 +28,13 @@ envoy_cc_library( envoy_cc_library( name = "kafka_request_lib", srcs = [ - "generated/kafka_request_resolver.cc", - "kafka_request.cc", + "kafka_request_parser.cc", + "kafka_request_resolver.cc", ], hdrs = [ - "generated/requests.h", "kafka_request.h", + "kafka_request_parser.h", + "requests.h", ], deps = [ ":parser_lib", @@ -42,6 +44,30 @@ envoy_cc_library( ], ) +genrule( + name = "kafka_generated_source", + srcs = [ + "@kafka_produce_request_spec//file", + "@kafka_fetch_request_spec//file", + ], + outs = [ + "requests.h", + "kafka_request_resolver.cc", + ], + cmd = "./$(location :kafka_code_generator) generate-source $(location requests.h) $(location kafka_request_resolver.cc) $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file)", + tools = [ + ":kafka_code_generator", + ], +) + +py_binary( + name = "kafka_code_generator", + srcs = ["protocol_code_generator/kafka_generator.py"], + data = glob(["protocol_code_generator/*.j2"]), + main = "protocol_code_generator/kafka_generator.py", + deps = ["@com_github_pallets_jinja//:jinja2"], +) + envoy_cc_library( name = "parser_lib", hdrs = ["parser.h"], @@ -57,15 +83,13 @@ envoy_cc_library( hdrs = [ "message.h", ], - deps = [ - ], ) envoy_cc_library( name = "serialization_lib", hdrs = [ - "generated/serialization_composite.h", "serialization.h", + "serialization_composite.h", ], deps = [ ":kafka_protocol_lib", diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h index a68d798ac400e..cfbd1c5337480 100644 --- a/source/extensions/filters/network/kafka/codec.h +++ b/source/extensions/filters/network/kafka/codec.h @@ -3,6 +3,8 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" +#include "extensions/filters/network/kafka/message.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -10,14 +12,13 @@ namespace Kafka { /** * Kafka message decoder - * @tparam MessageType message type (Kafka request or Kafka response) */ -template class MessageDecoder { +class MessageDecoder { public: virtual ~MessageDecoder() = default; /** - * 
Processes given buffer attempting to decode messages of type MessageType container within + * Processes given buffer attempting to decode messages contained within * @param data buffer instance */ virtual void onData(Buffer::Instance& data) PURE; @@ -27,7 +28,7 @@ template class MessageDecoder { * Kafka message decoder * @tparam MessageType message type (Kafka request or Kafka response) */ -template class MessageEncoder { +class MessageEncoder { public: virtual ~MessageEncoder() = default; @@ -35,7 +36,7 @@ template class MessageEncoder { * Encodes given message * @param message message to be encoded */ - virtual void encode(const MessageType& message) PURE; + virtual void encode(const Message& message) PURE; }; } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/generated/requests.h b/source/extensions/filters/network/kafka/generated/requests.h deleted file mode 100644 index 566f3ea012a8b..0000000000000 --- a/source/extensions/filters/network/kafka/generated/requests.h +++ /dev/null @@ -1,448 +0,0 @@ -// DO NOT EDIT - THIS FILE WAS GENERATED -// clang-format off -#pragma once -#include "extensions/filters/network/kafka/kafka_request.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -/* Represents 'partitions' element in OffsetCommitRequestV0 */ -struct OffsetCommitRequestV0Partition { - const int32_t partition_; - const int64_t offset_; - const NullableString metadata_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(offset_, dst); - written += encoder.encode(metadata_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV0Partition& rhs) const { - return - partition_ == rhs.partition_ && - offset_ == rhs.offset_ && - metadata_ == rhs.metadata_; - }; - -}; - -class OffsetCommitRequestV0PartitionDeserializer: - public CompositeDeserializerWith3Delegates< - OffsetCommitRequestV0Partition, - Int32Deserializer, - Int64Deserializer, - NullableStringDeserializer - >{}; - -/* Represents 'topics' element in OffsetCommitRequestV0 */ -struct OffsetCommitRequestV0Topic { - const std::string topic_; - const NullableArray partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV0Topic& rhs) const { - return - topic_ == rhs.topic_ && - partitions_ == rhs.partitions_; - }; - -}; - -class OffsetCommitRequestV0TopicDeserializer: - public CompositeDeserializerWith2Delegates< - OffsetCommitRequestV0Topic, - StringDeserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV0 : public Request { -public: - OffsetCommitRequestV0( - std::string group_id, - NullableArray topics - ): - Request{8, 0}, - group_id_{group_id}, - topics_{topics} - {}; - - bool operator==(const OffsetCommitRequestV0& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - written += encoder.encode(topics_, dst); - return written; - } - -private: - const std::string group_id_; - const NullableArray topics_; -}; - -class OffsetCommitRequestV0Deserializer: - public 
CompositeDeserializerWith2Delegates< - OffsetCommitRequestV0, - StringDeserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV0Parser : public RequestParser { -public: - OffsetCommitRequestV0Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; -}; - -/* Represents 'partitions' element in OffsetCommitRequestV1 */ -struct OffsetCommitRequestV1Partition { - const int32_t partition_; - const int64_t offset_; - const int64_t timestamp_; - const NullableString metadata_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(offset_, dst); - written += encoder.encode(timestamp_, dst); - written += encoder.encode(metadata_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV1Partition& rhs) const { - return - partition_ == rhs.partition_ && - offset_ == rhs.offset_ && - timestamp_ == rhs.timestamp_ && - metadata_ == rhs.metadata_; - }; - -}; - -class OffsetCommitRequestV1PartitionDeserializer: - public CompositeDeserializerWith4Delegates< - OffsetCommitRequestV1Partition, - Int32Deserializer, - Int64Deserializer, - Int64Deserializer, - NullableStringDeserializer - >{}; - -/* Represents 'topics' element in OffsetCommitRequestV1 */ -struct OffsetCommitRequestV1Topic { - const std::string topic_; - const NullableArray partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV1Topic& rhs) const { - return - topic_ == rhs.topic_ && - partitions_ == rhs.partitions_; - }; - -}; - -class OffsetCommitRequestV1TopicDeserializer: - public CompositeDeserializerWith2Delegates< - OffsetCommitRequestV1Topic, - StringDeserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV1 : public Request { -public: - OffsetCommitRequestV1( - std::string group_id, - int32_t generation_id, - std::string member_id, - NullableArray topics - ): - Request{8, 1}, - group_id_{group_id}, - generation_id_{generation_id}, - member_id_{member_id}, - topics_{topics} - {}; - - bool operator==(const OffsetCommitRequestV1& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && generation_id_ == rhs.generation_id_ && member_id_ == rhs.member_id_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - written += encoder.encode(generation_id_, dst); - written += encoder.encode(member_id_, dst); - written += encoder.encode(topics_, dst); - return written; - } - -private: - const std::string group_id_; - const int32_t generation_id_; - const std::string member_id_; - const NullableArray topics_; -}; - -class OffsetCommitRequestV1Deserializer: - public CompositeDeserializerWith4Delegates< - OffsetCommitRequestV1, - StringDeserializer, - Int32Deserializer, - StringDeserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV1Parser : public RequestParser { -public: - OffsetCommitRequestV1Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; -}; - -/* Represents 'partitions' element in OffsetCommitRequestV2 */ -struct OffsetCommitRequestV2Partition { - const int32_t partition_; - const int64_t offset_; - const NullableString metadata_; - - size_t encode(Buffer::Instance& dst, 
EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(offset_, dst); - written += encoder.encode(metadata_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV2Partition& rhs) const { - return - partition_ == rhs.partition_ && - offset_ == rhs.offset_ && - metadata_ == rhs.metadata_; - }; - -}; - -class OffsetCommitRequestV2PartitionDeserializer: - public CompositeDeserializerWith3Delegates< - OffsetCommitRequestV2Partition, - Int32Deserializer, - Int64Deserializer, - NullableStringDeserializer - >{}; - -/* Represents 'topics' element in OffsetCommitRequestV2 */ -struct OffsetCommitRequestV2Topic { - const std::string topic_; - const NullableArray partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV2Topic& rhs) const { - return - topic_ == rhs.topic_ && - partitions_ == rhs.partitions_; - }; - -}; - -class OffsetCommitRequestV2TopicDeserializer: - public CompositeDeserializerWith2Delegates< - OffsetCommitRequestV2Topic, - StringDeserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV2 : public Request { -public: - OffsetCommitRequestV2( - std::string group_id, - int32_t generation_id, - std::string member_id, - int64_t retention_time, - NullableArray topics - ): - Request{8, 2}, - group_id_{group_id}, - generation_id_{generation_id}, - member_id_{member_id}, - retention_time_{retention_time}, - topics_{topics} - {}; - - bool operator==(const OffsetCommitRequestV2& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && generation_id_ == rhs.generation_id_ && member_id_ == rhs.member_id_ && retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - written += encoder.encode(generation_id_, dst); - written += encoder.encode(member_id_, dst); - written += encoder.encode(retention_time_, dst); - written += encoder.encode(topics_, dst); - return written; - } - -private: - const std::string group_id_; - const int32_t generation_id_; - const std::string member_id_; - const int64_t retention_time_; - const NullableArray topics_; -}; - -class OffsetCommitRequestV2Deserializer: - public CompositeDeserializerWith5Delegates< - OffsetCommitRequestV2, - StringDeserializer, - Int32Deserializer, - StringDeserializer, - Int64Deserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV2Parser : public RequestParser { -public: - OffsetCommitRequestV2Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; -}; - -/* Represents 'partitions' element in OffsetCommitRequestV3 */ -struct OffsetCommitRequestV3Partition { - const int32_t partition_; - const int64_t offset_; - const NullableString metadata_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(partition_, dst); - written += encoder.encode(offset_, dst); - written += encoder.encode(metadata_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV3Partition& rhs) const { - return - partition_ == rhs.partition_ && - offset_ == rhs.offset_ && - metadata_ == rhs.metadata_; - }; - -}; - -class 
OffsetCommitRequestV3PartitionDeserializer: - public CompositeDeserializerWith3Delegates< - OffsetCommitRequestV3Partition, - Int32Deserializer, - Int64Deserializer, - NullableStringDeserializer - >{}; - -/* Represents 'topics' element in OffsetCommitRequestV3 */ -struct OffsetCommitRequestV3Topic { - const std::string topic_; - const NullableArray partitions_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(topic_, dst); - written += encoder.encode(partitions_, dst); - return written; - } - - bool operator==(const OffsetCommitRequestV3Topic& rhs) const { - return - topic_ == rhs.topic_ && - partitions_ == rhs.partitions_; - }; - -}; - -class OffsetCommitRequestV3TopicDeserializer: - public CompositeDeserializerWith2Delegates< - OffsetCommitRequestV3Topic, - StringDeserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV3 : public Request { -public: - OffsetCommitRequestV3( - std::string group_id, - int32_t generation_id, - std::string member_id, - int64_t retention_time, - NullableArray topics - ): - Request{8, 3}, - group_id_{group_id}, - generation_id_{generation_id}, - member_id_{member_id}, - retention_time_{retention_time}, - topics_{topics} - {}; - - bool operator==(const OffsetCommitRequestV3& rhs) const { - return request_header_ == rhs.request_header_ && group_id_ == rhs.group_id_ && generation_id_ == rhs.generation_id_ && member_id_ == rhs.member_id_ && retention_time_ == rhs.retention_time_ && topics_ == rhs.topics_; - }; - -protected: - size_t encodeDetails(Buffer::Instance& dst, EncodingContext& encoder) const override { - size_t written{0}; - written += encoder.encode(group_id_, dst); - written += encoder.encode(generation_id_, dst); - written += encoder.encode(member_id_, dst); - written += encoder.encode(retention_time_, dst); - written += encoder.encode(topics_, dst); - return written; - } - -private: - const std::string group_id_; - const int32_t generation_id_; - const std::string member_id_; - const int64_t retention_time_; - const NullableArray topics_; -}; - -class OffsetCommitRequestV3Deserializer: - public CompositeDeserializerWith5Delegates< - OffsetCommitRequestV3, - StringDeserializer, - Int32Deserializer, - StringDeserializer, - Int64Deserializer, - ArrayDeserializer - >{}; - -class OffsetCommitRequestV3Parser : public RequestParser { -public: - OffsetCommitRequestV3Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; -}; - -}}}} -// clang-format on diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index aff49af69b43a..fbd844e695e0b 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -4,11 +4,9 @@ #include "envoy/common/exception.h" -#include "common/common/assert.h" - -#include "extensions/filters/network/kafka/generated/serialization_composite.h" -#include "extensions/filters/network/kafka/parser.h" +#include "extensions/filters/network/kafka/message.h" #include "extensions/filters/network/kafka/serialization.h" +#include "extensions/filters/network/kafka/serialization_composite.h" namespace Envoy { namespace Extensions { @@ -32,157 +30,34 @@ struct RequestHeader { }; /** - * Context that is shared between parsers that are handling the same single message - */ -struct RequestContext { - int32_t remaining_request_size_{0}; - RequestHeader request_header_{}; -}; - -typedef std::shared_ptr 
RequestContextSharedPtr; - -/** - * Configuration object - * Resolves the parser that will be responsible for consuming the request-specific data - * In other words: provides (api_key, api_version) -> Parser function - */ -class RequestParserResolver { -public: - virtual ~RequestParserResolver() = default; - - /** - * Creates a parser that is going to process data specific for given api_key & api_version - * @param api_key request type - * @param api_version request version - * @param context context to be used by parser - * @return parser that is capable of processing data for given request type & version - */ - virtual ParserSharedPtr createParser(int16_t api_key, int16_t api_version, - RequestContextSharedPtr context) const; - - /** - * Request parser singleton - */ - static const RequestParserResolver INSTANCE; -}; - -/** - * Request parser responsible for consuming request length and setting up context with this data - * @see http://kafka.apache.org/protocol.html#protocol_common - */ -class RequestStartParser : public Parser { -public: - RequestStartParser(const RequestParserResolver& parser_resolver) - : parser_resolver_{parser_resolver}, context_{std::make_shared()} {}; - - /** - * Consumes INT32 bytes as request length and updates the context with that value - * @return RequestHeaderParser instance to process request header - */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; - - const RequestContextSharedPtr contextForTest() const { return context_; } - -private: - const RequestParserResolver& parser_resolver_; - const RequestContextSharedPtr context_; - Int32Deserializer request_length_; -}; - -/** - * Deserializer that extracts request header (4 fields) - * @see http://kafka.apache.org/protocol.html#protocol_messages - */ -class RequestHeaderDeserializer - : public CompositeDeserializerWith4Delegates {}; - -/** - * Parser responsible for computing request header and updating the context with data resolved - * On a successful parse uses resolved data (api_key & api_version) to determine next parser. + * Abstract Kafka request + * Contains data present in every request (the header with request key, version, etc.) * @see http://kafka.apache.org/protocol.html#protocol_messages */ -class RequestHeaderParser : public Parser { -public: - RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context) - : parser_resolver_{parser_resolver}, context_{context} {}; - - /** - * Uses data provided to compute request header - * @return Parser instance responsible for processing rest of the message - */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; - - const RequestContextSharedPtr contextForTest() const { return context_; } - -private: - const RequestParserResolver& parser_resolver_; - const RequestContextSharedPtr context_; - RequestHeaderDeserializer deserializer_; -}; - -/** - * Request parser uses a single deserializer to construct a request object - * This parser is responsible for consuming request-specific data (e.g. 
topic names) and always - returns a parsed message - * @param RT request class - * @param BT deserializer type corresponding to request class (should be subclass of - * Deserializer) - */ -template <typename RequestType, typename DeserializerType> class RequestParser : public Parser { +class AbstractRequest : public Message { public: - /** - * Create a parser with given context - * @param context parse context containing request header - */ - RequestParser(RequestContextSharedPtr context) : context_{context} {}; - - /** - * Consume enough data to fill in deserializer and receive the parsed request - * Fill in request's header with data stored in context - */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override { - context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); - if (deserializer.ready()) { - // after a successful parse, there should be nothing left - we have consumed all the bytes - ASSERT(0 == context_->remaining_request_size_); - RequestType request = deserializer.get(); - const RequestHeader& parsed_header = context_->request_header_; - request.setMetadata(parsed_header.correlation_id_, parsed_header.client_id_); - MessageSharedPtr msg = std::make_shared<RequestType>(request); - return ParseResponse::parsedMessage(msg); - } else { - return ParseResponse::stillWaiting(); - } - } + AbstractRequest(const RequestHeader& request_header) : request_header_{request_header} {}; protected: - RequestContextSharedPtr context_; - DeserializerType deserializer; // underlying request-specific deserializer + const RequestHeader request_header_; }; /** - * Abstract Kafka request - * Contains data present in every request - * @see http://kafka.apache.org/protocol.html#protocol_messages + * Concrete request that carries data particular to given request type */ -class Request : public Message { +template <typename RequestData> class ConcreteRequest : public AbstractRequest { public: /** * Request header fields need to be initialized by user in case of newly created requests */ - Request(int16_t api_key, int16_t api_version) : request_header_{api_key, api_version, 0, ""} {}; - - void setMetadata(const int32_t correlation_id, const NullableString& client_id) { - request_header_.correlation_id_ = correlation_id; - request_header_.client_id_ = client_id; - } + ConcreteRequest(const RequestHeader& request_header, const RequestData& data) + : AbstractRequest{request_header}, data_{data} {}; /** * Encodes given request into a buffer, with any extra configuration carried by the context */ - size_t encode(Buffer::Instance& dst, EncodingContext& context) const { + size_t encode(Buffer::Instance& dst) const override { + EncodingContext context{request_header_.api_version_}; size_t written{0}; // encode request header written += context.encode(request_header_.api_key_, dst); @@ -190,59 +65,34 @@ class Request : public Message { written += context.encode(request_header_.api_version_, dst); written += context.encode(request_header_.correlation_id_, dst); written += context.encode(request_header_.client_id_, dst); // encode request-specific data - written += encodeDetails(dst, context); + written += context.encode(data_, dst); return written; } -protected: - /** - * Encodes request-specific data into a buffer - */ - virtual size_t encodeDetails(Buffer::Instance&, EncodingContext&) const PURE; + bool operator==(const ConcreteRequest& rhs) const { + return request_header_ == rhs.request_header_ && data_ == rhs.data_; + }; - RequestHeader request_header_; +private: + const RequestData data_; };
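// NOTE (editorial sketch, not part of the original patch): assuming one of the generated
// payload types (e.g. OffsetCommitRequestV0) and Envoy's Buffer::OwnedImpl, a concrete
// request is built from a header plus payload and serializes itself into a buffer;
// the field values below are illustrative only:
//
//   const RequestHeader header{8, 0, 0, absl::nullopt};          // api_key = 8 (OffsetCommit), v0
//   const OffsetCommitRequestV0 data{"group-id", absl::nullopt}; // group id + null topics array
//   const ConcreteRequest<OffsetCommitRequestV0> request{header, data};
//   Buffer::OwnedImpl buffer;
//   const size_t written = request.encode(buffer);               // header fields first, then payload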
/** * Request that did not have api_key & api_version that could be matched with any of * request-specific parsers */ -class UnknownRequest : public Request { +class UnknownRequest : public AbstractRequest { public: - UnknownRequest(const RequestHeader& request_header) - : Request{request_header.api_key_, request_header.api_version_} { - setMetadata(request_header.correlation_id_, request_header.client_id_); - }; + UnknownRequest(const RequestHeader& request_header) : AbstractRequest{request_header} {}; -protected: // this isn't the prettiest, as we have thrown away the data // XXX(adam.kotwasinski) discuss capturing the data as-is, and simply putting it back // this would add ability to forward unknown types of requests in cluster-proxy - size_t encodeDetails(Buffer::Instance&, EncodingContext&) const override { + size_t encode(Buffer::Instance&) const override { throw EnvoyException("cannot serialize unknown request"); } }; -/** - * Sentinel parser that is responsible for consuming message bytes for messages that had unsupported - * api_key & api_version It does not attempt to capture any data, just throws it away until end of - * message - */ -class SentinelParser : public Parser { -public: - SentinelParser(RequestContextSharedPtr context) : context_{context} {}; - - /** - * Returns UnknownRequest - */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; - - const RequestContextSharedPtr contextForTest() const { return context_; } - -private: - const RequestContextSharedPtr context_; -}; - } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/kafka/kafka_request.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc similarity index 70% rename from source/extensions/filters/network/kafka/kafka_request.cc rename to source/extensions/filters/network/kafka/kafka_request_parser.cc index b41d0ea714c9a..5cadeccb069b7 100644 --- a/source/extensions/filters/network/kafka/kafka_request.cc +++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc @@ -1,7 +1,4 @@ -#include "extensions/filters/network/kafka/kafka_request.h" - -#include "extensions/filters/network/kafka/generated/requests.h" -#include "extensions/filters/network/kafka/parser.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" namespace Envoy { namespace Extensions { @@ -20,10 +17,20 @@ ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining } ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remaining) { - context_->remaining_request_size_ -= deserializer_.feed(buffer, remaining); + try { + context_->remaining_request_size_ -= deserializer_->feed(buffer, remaining); + } catch (const EnvoyException&) { + buffer += context_->remaining_request_size_; + remaining -= context_->remaining_request_size_; + context_->remaining_request_size_ = 0; + + const RequestHeader header{-1, -1, -1, absl::nullopt}; + return ParseResponse::parsedMessage( + std::make_shared<UnknownRequest>(header)); + } - if (deserializer_.ready()) { - RequestHeader request_header = deserializer_.get(); + if (deserializer_->ready()) { + RequestHeader request_header = deserializer_->get(); context_->request_header_ = request_header; ParserSharedPtr next_parser = parser_resolver_.createParser( request_header.api_key_, request_header.api_version_, context_); diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h new file mode 100644 index 0000000000000..0580e705784bb --- /dev/null +++ 
b/source/extensions/filters/network/kafka/kafka_request_parser.h @@ -0,0 +1,204 @@ +#pragma once + +#include <memory> + +#include "envoy/common/exception.h" + +#include "common/common/assert.h" + +#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Context that is shared between parsers that are handling the same single message + */ +struct RequestContext { + int32_t remaining_request_size_{0}; + RequestHeader request_header_{}; +}; + +typedef std::shared_ptr<RequestContext> RequestContextSharedPtr; + +/** + * Configuration object + * Resolves the parser that will be responsible for consuming the request-specific data + * In other words: provides (api_key, api_version) -> Parser function + */ +class RequestParserResolver { +public: + virtual ~RequestParserResolver() = default; + + /** + * Creates a parser that is going to process data specific for given api_key & api_version + * @param api_key request type + * @param api_version request version + * @param context context to be used by parser + * @return parser that is capable of processing data for given request type & version + */ + virtual ParserSharedPtr createParser(int16_t api_key, int16_t api_version, + RequestContextSharedPtr context) const; + + /** + * Request parser singleton + */ + static const RequestParserResolver INSTANCE; +}; + +/** + * Request parser responsible for consuming request length and setting up context with this data + * @see http://kafka.apache.org/protocol.html#protocol_common + */ +class RequestStartParser : public Parser { +public: + RequestStartParser(const RequestParserResolver& parser_resolver) + : parser_resolver_{parser_resolver}, context_{std::make_shared<RequestContext>()} {}; + + /** + * Consumes INT32 bytes as request length and updates the context with that value + * @return RequestHeaderParser instance to process request header + */ + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestParserResolver& parser_resolver_; + const RequestContextSharedPtr context_; + Int32Deserializer request_length_; +}; + +/** + * Deserializer that extracts request header (4 fields) + * Can throw, as one of the fields (client-id) can throw (nullable string with invalid length) + * @see http://kafka.apache.org/protocol.html#protocol_messages + */ +class RequestHeaderDeserializer + : public CompositeDeserializerWith4Delegates<RequestHeader, Int16Deserializer, Int16Deserializer, Int32Deserializer, NullableStringDeserializer> {}; + +typedef std::unique_ptr<RequestHeaderDeserializer> RequestHeaderDeserializerPtr; + +/** + * Parser responsible for computing request header and updating the context with data resolved + * On a successful parse uses resolved data (api_key & api_version) to determine next parser. 
+ * @see http://kafka.apache.org/protocol.html#protocol_messages + */ +class RequestHeaderParser : public Parser { +public: + // default constructor + RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context) + : RequestHeaderParser{parser_resolver, context, + std::make_unique<RequestHeaderDeserializer>()} {}; + + // visible for testing + RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context, + RequestHeaderDeserializerPtr deserializer) + : parser_resolver_{parser_resolver}, context_{context}, deserializer_{ + std::move(deserializer)} {}; + + /** + * Uses data provided to compute request header + * @return Parser instance responsible for processing rest of the message + */ + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestParserResolver& parser_resolver_; + const RequestContextSharedPtr context_; + RequestHeaderDeserializerPtr deserializer_; +}; + +/** + * Request parser uses a single deserializer to construct a request object + * This parser is responsible for consuming request-specific data (e.g. topic names) and always + * returns a parsed message + * @param RequestType request class + * @param DeserializerType deserializer type corresponding to request class (should be subclass of + * Deserializer) + */ +template <typename RequestType, typename DeserializerType> class RequestParser : public Parser { +public: + /** + * Create a parser with given context + * @param context parse context containing request header + */ + RequestParser(RequestContextSharedPtr context) : context_{context} {}; + + /** + * Consume enough data to fill in deserializer and receive the parsed request + * Fill in request's header with data stored in context + */ + ParseResponse parse(const char*& buffer, uint64_t& remaining) override { + try { + context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); + } catch (const EnvoyException&) { + // treat the whole request as invalid, throw away the rest of the data + ignoreRestOfRequest(buffer, remaining); + return ParseResponse::parsedMessage( + std::make_shared<UnknownRequest>(context_->request_header_)); + } + + if (deserializer.ready()) { + if (0 == context_->remaining_request_size_) { + // after a successful parse, there should be nothing left - we have consumed all the bytes + MessageSharedPtr msg = std::make_shared<ConcreteRequest<RequestType>>( + context_->request_header_, deserializer.get()); + return ParseResponse::parsedMessage(msg); + } else { + // the message makes no sense, the deserializer that matches the schema consumed all + // necessary data, but there's still unconsumed bytes + ignoreRestOfRequest(buffer, remaining); + return ParseResponse::parsedMessage( + std::make_shared<UnknownRequest>(context_->request_header_)); + } + } else { + return ParseResponse::stillWaiting(); + } + } + +protected: + RequestContextSharedPtr context_; + DeserializerType deserializer; // underlying request-specific deserializer + +private: + // moves the pointers until the end of request, so that the data that does not match the + // deserializers gets ignored + void ignoreRestOfRequest(const char*& buffer, uint64_t& remaining) { + buffer += context_->remaining_request_size_; + remaining -= context_->remaining_request_size_; + context_->remaining_request_size_ = 0; + } +};
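// NOTE (editorial sketch, not part of the original patch): a single request flows through
// three stages, each handing control to the next via ParseResponse:
//   RequestStartParser   -- consumes the INT32 length and fills the shared context
//   RequestHeaderParser  -- consumes the header, resolves the request-specific parser
//   RequestParser<T, D>  -- consumes the payload and emits a ConcreteRequest<T>
// Assuming ParseResponse exposes the next parser and the parsed message (its real
// interface lives in parser.h, which this patch does not show), a feed loop could be:
//
//   ParserSharedPtr parser = std::make_shared<RequestStartParser>(RequestParserResolver::INSTANCE);
//   while (remaining > 0) {
//     ParseResponse result = parser->parse(data, remaining);
//     // switch to the returned next parser, or publish the parsed message and start over
//   }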
+/** + * Sentinel parser that is responsible for consuming message bytes for messages that had unsupported + * api_key & api_version. It does not attempt to capture any data, just throws it away until end of + * message + */ +class SentinelParser : public Parser { +public: + SentinelParser(RequestContextSharedPtr context) : context_{context} {}; + + /** + * Returns UnknownRequest + */ + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestContextSharedPtr context_; +}; + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h index 7551888fe05e3..eec046295d002 100644 --- a/source/extensions/filters/network/kafka/message.h +++ b/source/extensions/filters/network/kafka/message.h @@ -3,6 +3,7 @@ #include #include +#include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" namespace Envoy { @@ -16,6 +17,12 @@ namespace Kafka { class Message { public: virtual ~Message() = default; + + /** + * Encode the contents of this message into a given buffer + * @param dst buffer instance to keep serialized message + */ + virtual size_t encode(Buffer::Instance& dst) const PURE; }; typedef std::shared_ptr<Message> MessageSharedPtr; diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 new file mode 100644 index 0000000000000..aaf4ebce48633 --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 @@ -0,0 +1,37 @@ +struct {{ complex_type.name }} { + {% for field in complex_type.fields %} + const {{ field.field_declaration() }}_;{% endfor %} + {% for constructor in complex_type.compute_constructors() %} + // constructor used in versions: {{ constructor['versions'] }} + {{ constructor['full_declaration'] }}{% endfor %} + + {% if complex_type.fields|length > 0 %} + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + const int16_t api_version = encoder.apiVersion(); + size_t written{0};{% for field in complex_type.fields %} + if (api_version >= {{ field.version_usage[0] }} && api_version < {{ field.version_usage[-1] + 1 }}) { + written += encoder.encode({{ field.name }}_, dst); + }{% endfor %} + return written; + } + {% else %} + size_t encode(Buffer::Instance&, EncodingContext&) const { + return 0; + } + {% endif %} + + {% if complex_type.fields|length > 0 %} + bool operator==(const {{ complex_type.name }}& rhs) const { + {% else %} + bool operator==(const {{ complex_type.name }}&) const { + {% endif %} + return true{% for field in complex_type.fields %} + && {{ field.name }}_ == rhs.{{ field.name }}_{% endfor %}; + }; + +}; +{% for field_list in complex_type.compute_field_lists() %} +class {{ complex_type.name }}V{{ field_list.version }}Deserializer: + public CompositeDeserializerWith{{ field_list.field_count() }}Delegates<{{ complex_type.name }}{% for field in field_list.used_fields() %}, {{ field.deserializer_name_in_version(field_list.version) }}{% endfor %}>{}; +{% endfor %}
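// NOTE (editorial example, not part of the original patch): for a hypothetical two-field
// type used only in version 0, the template above renders roughly the following C++
// (names are illustrative only):
//
//   struct ExampleTopic {
//     const std::string name_;
//     const std::vector<int32_t> partitions_;
//     // constructor used in versions: [0]
//     ExampleTopic(std::string name, std::vector<int32_t> partitions): name_{name}, partitions_{partitions} {};
//     size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
//       const int16_t api_version = encoder.apiVersion();
//       size_t written{0};
//       if (api_version >= 0 && api_version < 1) { written += encoder.encode(name_, dst); }
//       if (api_version >= 0 && api_version < 1) { written += encoder.encode(partitions_, dst); }
//       return written;
//     }
//     bool operator==(const ExampleTopic& rhs) const {
//       return true && name_ == rhs.name_ && partitions_ == rhs.partitions_;
//     };
//   };
//   class ExampleTopicV0Deserializer: public CompositeDeserializerWith2Delegates<ExampleTopic, StringDeserializer, ArrayDeserializer<int32_t, Int32Deserializer>>{};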
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py new file mode 100755 index 0000000000000..3e4032c3b100b --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -0,0 +1,402 @@ +#!/usr/bin/python + +# usage: +# kafka_generator.py COMMAND OUTPUT_FILES INPUT_FILES +# where: +# COMMAND : 'generate-source', to generate source files +# 'generate-test', to generate test files +# OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', +# if generate-test: location of 'requests_test.cc' +# INPUT_FILES : Kafka protocol json files to be processed + +def main(): + import sys + import os + + command = sys.argv[1] + if 'generate-source' == command: + requests_h_file = os.path.abspath(sys.argv[2]) + kafka_request_resolver_cc_file = os.path.abspath(sys.argv[3]) + input_files = sys.argv[4:] + elif 'generate-test' == command: + requests_test_cc_file = os.path.abspath(sys.argv[2]) + input_files = sys.argv[3:] + else: + raise ValueError('invalid command: ' + command) + + import re + import json + + requests = [] + + for input_file in input_files: + with open(input_file, 'r') as fd: + raw_contents = fd.read() + without_comments = re.sub(r'//.*\n', '', raw_contents) + request_spec = json.loads(without_comments) + request = parse_request(request_spec) + if request is not None: # debugging + requests.append(request) + + requests.sort(key = lambda x: x.get_extra('api_key')) + + + if 'generate-source' == command: + complex_type_template = RenderingHelper.get_template('complex_type_template.j2') + request_parsers_template = RenderingHelper.get_template('request_parser.j2') + requests_h_contents = '' + + for request in requests: + # structures holding payload data + for dependency in request.declaration_chain: + requests_h_contents += complex_type_template.render(complex_type = dependency) + # request parser + requests_h_contents += request_parsers_template.render(complex_type = request) + + # full file with headers, namespace declaration etc. 
+ requests_header_template = RenderingHelper.get_template('requests_h.j2') + contents = requests_header_template.render(contents = requests_h_contents) + + with open(requests_h_file, 'w') as fd: + fd.write(contents) + + kafka_request_resolver_template = RenderingHelper.get_template('kafka_request_resolver_cc.j2') + contents = kafka_request_resolver_template.render(request_types = requests) + + with open(kafka_request_resolver_cc_file, 'w') as fd: + fd.write(contents) + + if 'generate-test' == command: + requests_test_template = RenderingHelper.get_template('requests_test_cc.j2') + contents = requests_test_template.render(request_types = requests) + + with open(requests_test_cc_file, 'w') as fd: + fd.write(contents) + +def parse_request(spec): + # a request is just a complex type, that has name & versions kept in differently named fields + request_type_name = spec['name'] + request_versions = Statics.parse_version_string(spec['validVersions'], (1 << 16) - 1) + return parse_complex_type(request_type_name, spec, request_versions).with_extra('api_key', spec['apiKey']) + +def parse_complex_type(type_name, field_spec, versions): + fields = [] + for child_field in field_spec['fields']: + child = parse_field(child_field, versions[-1]) + fields.append(child) + return Complex(type_name, fields, versions) + +def parse_field(field_spec, highest_possible_version): + # obviously, field cannot be used in version higher than its type's usage + version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) + version_usage_as_nullable = Statics.parse_version_string(field_spec['nullableVersions'], highest_possible_version) if 'nullableVersions' in field_spec else range(-1) + parsed_type = parse_type(field_spec['type'], field_spec, highest_possible_version) + return FieldSpec(field_spec['name'], parsed_type, version_usage, version_usage_as_nullable) + +def parse_type(type_name, field_spec, highest_possible_version): + # array types are defined as `[]underlying_type` instead of having its own element with type inside :\ + if (type_name.startswith('[]')): + underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) + return Array(underlying_type) + else: + if (type_name in Primitive.PRIMITIVE_TYPE_NAMES): + return Primitive(type_name, field_spec.get('default')) + else: + versions = Statics.parse_version_string(field_spec['versions'], highest_possible_version) + return parse_complex_type(type_name, field_spec, versions) + +class Statics: + + @staticmethod + def parse_version_string(raw_versions, highest_possible_version): + if raw_versions.endswith('+'): + return range(int(raw_versions[:-1]), highest_possible_version + 1) + else: + if '-' in raw_versions: + tokens = raw_versions.split('-', 1) + return range(int(tokens[0]), int(tokens[1]) + 1) + else: + single_version = int(raw_versions) + return range(single_version, single_version + 1)
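# NOTE (editorial example, not part of the original patch): the version strings found in
# the protocol specs map onto Python ranges as follows:
#   Statics.parse_version_string('0+', 3)  -> range(0, 4)  # open-ended: 0..highest
#   Statics.parse_version_string('1-2', 3) -> range(1, 3)  # closed interval 1..2
#   Statics.parse_version_string('2', 3)   -> range(2, 3)  # single version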
+class FieldList: + + def __init__(self, version, fields): + self.version = version + self.fields = fields + + def used_fields(self): + return filter(lambda x: x.used_in_version(self.version), self.fields) + + def constructor_signature(self): + parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) + return ', '.join(parameter_spec) + + def constructor_init_list(self): + init_list = [] + for field in self.fields: + if field.used_in_version(self.version): + if field.is_nullable(): + if field.is_nullable_in_version(self.version): + # field is optional, and the parameter is optional in this version + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # field is optional, and the parameter is T in this version + init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # field is T, so parameter cannot be optional + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # field is not used in this version, so we need to put in default value + init_list_item = '%s_{%s}' % (field.name, field.default_value()) + init_list.append(init_list_item) + return ', '.join(init_list) + + def field_count(self): + return len(self.used_fields()) + + def example_value(self): + return ', '.join(map(lambda x: x.example_value_for_test(self.version), self.used_fields())) + +class FieldSpec: + + def __init__(self, name, type, version_usage, version_usage_as_nullable): + import re + separated = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + self.name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', separated).lower() + self.type = type + self.version_usage = version_usage + self.version_usage_as_nullable = version_usage_as_nullable + + def is_nullable(self): + return len(self.version_usage_as_nullable) > 0 + + def is_nullable_in_version(self, version): + return version in self.version_usage_as_nullable + + def used_in_version(self, version): + return version in self.version_usage + + def field_declaration(self): + if self.is_nullable(): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def parameter_declaration(self, version): + if self.is_nullable_in_version(version): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def default_value(self): + if self.is_nullable(): + return '{%s}' % self.type.default_value() + else: + return str(self.type.default_value()) + + def example_value_for_test(self, version): + if self.is_nullable(): + return 'absl::make_optional<%s>(%s)' % (self.type.name, self.type.example_value_for_test(version)) + else: + return str(self.type.example_value_for_test(version)) + + def deserializer_name_in_version(self, version): + if self.is_nullable_in_version(version): + return 'Nullable%s' % self.type.deserializer_name_in_version(version) + else: + return self.type.deserializer_name_in_version(version) + + def is_printable(self): + return self.type.is_printable() + +class TypeSpecification: + + def deserializer_name_in_version(self, version): + raise NotImplementedError() + + def default_value(self): + raise NotImplementedError() + + def example_value_for_test(self, version): + raise NotImplementedError() + + def is_printable(self): + raise NotImplementedError() + +class Array(TypeSpecification): + + def __init__(self, underlying): + self.underlying = underlying + self.declaration_chain = self.underlying.declaration_chain + + @property + def name(self): + return 'std::vector<%s>' % self.underlying.name + + def deserializer_name_in_version(self, version): + return 'ArrayDeserializer<%s, %s>' % (self.underlying.name, self.underlying.deserializer_name_in_version(version)) + + def default_value(self): + return '{}' + + def example_value_for_test(self, version): + return 'std::vector<%s>{ %s }' % (self.underlying.name, self.underlying.example_value_for_test(version)) + + def is_printable(self): + return self.underlying.is_printable() + +class Primitive(TypeSpecification): + + PRIMITIVE_TYPE_NAMES = 
['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] + + KAFKA_TYPE_TO_ENVOY_TYPE = { + 'string': 'std::string', + 'bool': 'bool', + 'int8': 'int8_t', + 'int16': 'int16_t', + 'int32': 'int32_t', + 'int64': 'int64_t', + 'bytes': 'Bytes', + } + + KAFKA_TYPE_TO_DESERIALIZER = { + 'string': 'StringDeserializer', + 'bool': 'BooleanDeserializer', + 'int8': 'Int8Deserializer', + 'int16': 'Int16Deserializer', + 'int32': 'Int32Deserializer', + 'int64': 'Int64Deserializer', + 'bytes': 'BytesDeserializer', + } + + # https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages + KAFKA_TYPE_TO_DEFAULT_VALUE = { + 'string': '""', + 'bool': 'false', + 'int8': '0', + 'int16': '0', + 'int32': '0', + 'int64': '0', + 'bytes': '{}', + } + + # to make test code more readable + KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { + 'string': '"string"', + 'bool': 'false', + 'int8': '8', + 'int16': '16', + 'int32': '32', + 'int64': '64ll', + 'bytes': 'Bytes({0, 1, 2, 3})', + } + + def __init__(self, name, custom_default_value): + self.original_name = name + self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) + self.custom_default_value = custom_default_value + self.declaration_chain = [] + self.deserializer_name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) + + @staticmethod + def compute(name, map): + if name in map: + return map[name] + else: + raise ValueError(name) + + def deserializer_name_in_version(self, version): + return self.deserializer_name + + def default_value(self): + if self.custom_default_value is not None: + return self.custom_default_value + else: + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) + + def example_value_for_test(self, version): + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) + + def is_printable(self): + return self.name not in ['Bytes'] + +class Complex(TypeSpecification): + + def __init__(self, name, fields, versions): + self.name = name + self.fields = fields + self.versions = versions + self.declaration_chain = self.__compute_declaration_chain() + self.attributes = {} + + def __compute_declaration_chain(self): + result = [] + for field in self.fields: + result.extend(field.type.declaration_chain) + result.append(self) + return result + + def with_extra(self, key, value): + self.attributes[key] = value + return self + + def get_extra(self, key): + return self.attributes[key] + + def compute_constructors(self): + # field lists for different versions may not differ (as Kafka can bump version without any changes) + # but constructors need to be unique + signature_to_constructor = {} + for field_list in self.compute_field_lists(): + signature = field_list.constructor_signature() + constructor = signature_to_constructor.get(signature) + if constructor is None: + entry = {} + entry['versions'] = [ field_list.version ] + entry['signature'] = signature + if (len(signature) > 0): + entry['full_declaration'] = '%s(%s): %s {};' % (self.name, signature, field_list.constructor_init_list()) + else: + entry['full_declaration'] = '%s() {};' % self.name + signature_to_constructor[signature] = entry + else: + constructor['versions'].append(field_list.version) + return sorted(signature_to_constructor.values(), key = lambda x: x['versions'][0]) + + def compute_field_lists(self): + field_lists = [] + for version in self.versions: + field_list = FieldList(version, self.fields) + field_lists.append(field_list) + return field_lists; + + 
def deserializer_name_in_version(self, version): + return '%sV%dDeserializer' % (self.name, version) + + def default_value(self): + raise NotImplementedError('unable to create default value of complex type') + + def example_value_for_test(self, version): + field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) + example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) + return '%s(%s)' % (self.name, ', '.join(example_values)) + + def is_printable(self): + return True + +class RenderingHelper: + + @staticmethod + def get_template(template): + import jinja2 + import os + env = jinja2.Environment(loader = jinja2.FileSystemLoader(searchpath = os.path.dirname(os.path.abspath(__file__)))) + return env.get_template(template) + +if __name__ == "__main__": + main() diff --git a/source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 similarity index 51% rename from source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc rename to source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 index 2e8c26f25bad5..df08eec07a074 100644 --- a/source/extensions/filters/network/kafka/generated/kafka_request_resolver.cc +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 @@ -1,7 +1,7 @@ // DO NOT EDIT - THIS FILE WAS GENERATED // clang-format off -#include "extensions/filters/network/kafka/generated/requests.h" -#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/requests.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" #include "extensions/filters/network/kafka/parser.h" namespace Envoy { @@ -14,18 +14,10 @@ const RequestParserResolver RequestParserResolver::INSTANCE; ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version, RequestContextSharedPtr context) const { - if (8 == api_key && 0 == api_version) { - return std::make_shared<OffsetCommitRequestV0Parser>(context); - } - if (8 == api_key && 1 == api_version) { - return std::make_shared<OffsetCommitRequestV1Parser>(context); - } - if (8 == api_key && 2 == api_version) { - return std::make_shared<OffsetCommitRequestV2Parser>(context); - } - if (8 == api_key && 3 == api_version) { - return std::make_shared<OffsetCommitRequestV3Parser>(context); - } +{% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %} + if ({{ request_type.get_extra('api_key') }} == api_key && {{ field_list.version }} == api_version) { + return std::make_shared<{{ request_type.name }}V{{ field_list.version }}Parser>(context); + }{% endfor %}{% endfor %} return std::make_shared<SentinelParser>(context); } diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 new file mode 100644 index 0000000000000..8f6d655e1fa31 --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 @@ -0,0 +1,6 @@ +{% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> { +public: + {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; +}; + +{% endfor %} \ No newline at end of file diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 
b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 new file mode 100644 index 0000000000000..d6f454369c9de --- /dev/null +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 @@ -0,0 +1,15 @@ +// DO NOT EDIT - THIS FILE WAS GENERATED +// clang-format off +#pragma once +#include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +{{ contents }} + +}}}} +// clang-format on diff --git a/test/extensions/filters/network/kafka/generated/requests_test.cc b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 similarity index 52% rename from test/extensions/filters/network/kafka/generated/requests_test.cc rename to source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 index 43a7b08599b86..45da13b677f3d 100644 --- a/test/extensions/filters/network/kafka/generated/requests_test.cc +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -1,6 +1,6 @@ // DO NOT EDIT - THIS FILE WAS GENERATED // clang-format off -#include "extensions/filters/network/kafka/generated/requests.h" +#include "extensions/filters/network/kafka/requests.h" #include "extensions/filters/network/kafka/request_codec.h" #include "test/mocks/server/mocks.h" @@ -26,7 +26,7 @@ class MockMessageListener : public RequestCallback { }; template <typename T> std::shared_ptr<T> RequestDecoderTest::serializeAndDeserialize(T request) { - RequestEncoder serializer{buffer_}; + MessageEncoderImpl serializer{buffer_}; serializer.encode(request); std::shared_ptr<MockMessageListener> mock_listener = std::make_shared<MockMessageListener>(); @@ -39,46 +39,11 @@ template <typename T> std::shared_ptr<T> RequestDecoderTest::serializeAndDeseria return std::dynamic_pointer_cast<T>(receivedMessage); }; - -TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV0) { - // given - OffsetCommitRequestV0 request = {"string", {{ {"string", {{ {32, 64, {"nullable"}, } }}, } }}, }; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV1) { - // given - OffsetCommitRequestV1 request = {"string", 32, "string", {{ {"string", {{ {32, 64, 64, {"nullable"}, } }}, } }}, }; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV2) { - // given - OffsetCommitRequestV2 request = {"string", 32, "string", 64, {{ {"string", {{ {32, 64, {"nullable"}, } }}, } }}, }; - - // when - auto received = serializeAndDeserialize(request); - - // then - ASSERT_NE(received, nullptr); - ASSERT_EQ(*received, request); -} - -TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV3) { +{% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %} +TEST_F(RequestDecoderTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) { // given - OffsetCommitRequestV3 request = {"string", 32, "string", 64, {{ {"string", {{ {32, 64, {"nullable"}, } }}, } }}, }; + {{ request_type.name }} data = { {{ field_list.example_value() }} }; + ConcreteRequest<{{ request_type.name }}> request = { { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, 0, absl::nullopt }, data }; // when auto received = 
serializeAndDeserialize(request); @@ -87,6 +52,7 @@ TEST_F(RequestDecoderTest, shouldParseOffsetCommitRequestV3) { ASSERT_NE(received, nullptr); ASSERT_EQ(*received, request); } +{% endfor %}{% endfor %} } // namespace Kafka } // namespace NetworkFilters diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index 9ef3a54471758..cc8c799749bcf 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -53,15 +53,15 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl } } -void RequestEncoder::encode(const Request& message) { - EncodingContext encoder; +void MessageEncoderImpl::encode(const Message& message) { Buffer::OwnedImpl data_buffer; // TODO (adam.kotwasinski) precompute the size instead of using temporary // also, when we have 'computeSize' method, then we can push encoding request's size into // Request::encode - int32_t data_len = encoder.encode(message, data_buffer); // encode data computing data length - encoder.encode(data_len, output_); // encode data length into result - output_.add(data_buffer); // copy data into result + int32_t data_len = message.encode(data_buffer); // encode data computing data length + EncodingContext encoder{-1}; + encoder.encode(data_len, output_); // encode data length into result + output_.add(data_buffer); // copy data into result } } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index 65c2f36e51590..c856b9461fcdb 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -5,8 +5,8 @@ #include "extensions/filters/network/kafka/codec.h" #include "extensions/filters/network/kafka/kafka_request.h" +#include "extensions/filters/network/kafka/kafka_request_parser.h" #include "extensions/filters/network/kafka/parser.h" -#include "extensions/filters/network/kafka/serialization.h" namespace Envoy { namespace Extensions { @@ -38,7 +38,7 @@ typedef std::shared_ptr<RequestCallback> RequestCallbackSharedPtr; * Stores parse state (have `onData` invoked multiple times for messages that are larger than single * buffer) */ -class RequestDecoder : public MessageDecoder, public Logger::Loggable<Logger::Id::kafka> { +class RequestDecoder : public MessageDecoder { public: /** * Creates a decoder that can decode requests specified by RequestParserResolver, notifying @@ -69,19 +69,19 @@ class RequestDecoder : public MessageDecoder, public Logger::Loggable { +class MessageEncoderImpl : public MessageEncoder<Message> { public: /** * Wraps buffer with encoder */ - RequestEncoder(Buffer::Instance& output) : output_(output) {} + MessageEncoderImpl(Buffer::Instance& output) : output_(output) {} /** * Encodes request into wrapped buffer */ - void encode(const Request& message) override; + void encode(const Message& message) override; private: Buffer::Instance& output_;
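// NOTE (editorial sketch, not part of the original patch): the encoder emits Kafka's
// length-prefixed framing - an INT32 payload size followed by the payload itself:
//
//   +----------------+---------------------------------------------+
//   | INT32 size N   | N bytes: request header + request-specific  |
//   | (big-endian)   | data, as produced by Message::encode()      |
//   +----------------+---------------------------------------------+
//
// so RequestStartParser on the receiving side can consume the INT32 first and then
// knows exactly how many bytes belong to the current request.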
diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index 05a8f1e4ae0da..1b62c516f3574 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -178,6 +178,9 @@ class BooleanDeserializer : public Deserializer<bool> { */ class StringDeserializer : public Deserializer<std::string> { public: + /** + * Can throw EnvoyException if given string length is not valid + */ size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { @@ -190,7 +193,7 @@ class StringDeserializer : public Deserializer<std::string> { if (required_ >= 0) { data_buf_ = std::vector<char>(required_); } else { - throw EnvoyException(fmt::format("invalid std::string length: {}", required_)); + throw EnvoyException(fmt::format("invalid STRING length: {}", required_)); } length_consumed_ = true; } @@ -237,6 +240,9 @@ class StringDeserializer : public Deserializer<std::string> { */ class NullableStringDeserializer : public Deserializer<NullableString> { public: + /** + * Can throw EnvoyException if given string length is not valid + */ size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { @@ -307,6 +313,9 @@ class NullableStringDeserializer : public Deserializer<NullableString> { */ class BytesDeserializer : public Deserializer<Bytes> { public: + /** + * Can throw EnvoyException if given bytes length is not valid + */ size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { @@ -364,6 +373,9 @@ class BytesDeserializer : public Deserializer<Bytes> { */ class NullableBytesDeserializer : public Deserializer<NullableBytes> { public: + /** + * Can throw EnvoyException if given bytes length is not valid + */ size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { @@ -442,8 +454,88 @@ class NullableBytesDeserializer : public Deserializer<NullableBytes> { * follow. A null array is represented with a length of -1. */ template <typename ResponseType, typename DeserializerType> -class ArrayDeserializer : public Deserializer<NullableArray<ResponseType>> { +class ArrayDeserializer : public Deserializer<std::vector<ResponseType>> { +public: + /** + * Can throw EnvoyException if array length is invalid or if DeserializerType can throw + */ + size_t feed(const char*& buffer, uint64_t& remaining) { + + const size_t length_consumed = length_buf_.feed(buffer, remaining); + if (!length_buf_.ready()) { + // break early: we still need to fill in length buffer + return length_consumed; + } + + if (!length_consumed_) { + required_ = length_buf_.get(); + if (required_ >= 0) { + children_ = std::vector<DeserializerType>(required_); + } else { + throw EnvoyException(fmt::format("invalid ARRAY length: {}", required_)); + } + length_consumed_ = true; + } + + if (ready_) { + return length_consumed; + } + + size_t child_consumed{0}; + for (DeserializerType& child : children_) { + child_consumed += child.feed(buffer, remaining); + } + + bool children_ready_ = true; + for (DeserializerType& child : children_) { + children_ready_ &= child.ready(); + } + ready_ = children_ready_; + + return length_consumed + child_consumed; + } + + bool ready() const { return ready_; } + + std::vector<ResponseType> get() const { + std::vector<ResponseType> result{}; + result.reserve(children_.size()); + for (const DeserializerType& child : children_) { + const ResponseType child_result = child.get(); + result.push_back(child_result); + } + return result; + } + +private: + Int32Deserializer length_buf_; + bool length_consumed_{false}; + int32_t required_; + std::vector<DeserializerType> children_; + bool children_setup_{false}; + bool ready_{false}; +}; + +/** + * Deserializer for nullable array of objects of the same type + * + * First reads the length of the array, then initializes N underlying deserializers of type + * DeserializerType. After the last of N deserializers is ready, the results of each of them are + * gathered and put in a vector 
+ * @param ResponseType result type returned by deserializer of type DeserializerType + * @param DeserializerType underlying deserializer type + * + * From documentation: + * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. + * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T + * follow. A null array is represented with a length of -1. + */ +template <typename ResponseType, typename DeserializerType> +class NullableArrayDeserializer : public Deserializer<NullableArray<ResponseType>> { public: + /** + * Can throw EnvoyException if array length is invalid or if DeserializerType can throw + */ size_t feed(const char*& buffer, uint64_t& remaining) { const size_t length_consumed = length_buf_.feed(buffer, remaining); @@ -462,7 +554,7 @@ class ArrayDeserializer : public Deserializer<NullableArray<ResponseType>> { ready_ = true; } if (required_ < NULL_ARRAY_LENGTH) { - throw EnvoyException(fmt::format("invalid array length: {}", required_)); + throw EnvoyException(fmt::format("invalid NULLABLE_ARRAY length: {}", required_)); } length_consumed_ = true; @@ -496,7 +588,7 @@ class ArrayDeserializer : public Deserializer<NullableArray<ResponseType>> { const ResponseType child_result = child.get(); result.push_back(child_result); } - return {result}; + return result; } else { return absl::nullopt; } @@ -517,9 +609,17 @@ class ArrayDeserializer : public Deserializer<NullableArray<ResponseType>> { * Encodes provided argument in Kafka format * In case of primitive types, this is done explicitly as per spec * In case of composite types, this is done by calling 'encode' on provided argument + * + * This object also carries extra information that is used while traversing the request + * structure-tree during encoding (currently api_version, as different request versions serialize + * differently) */ +// XXX (adam.kotwasinski) that class might be split into Request/ResponseEncodingContext in future, +// but leaving it as it is now class EncodingContext { public: + EncodingContext(int16_t api_version) : api_version_{api_version} {}; + /** * Encode given reference in a buffer * @return bytes written */ @@ -530,7 +630,18 @@ class EncodingContext { * Encode given array in a buffer * @return bytes written */ + template <typename T> size_t encode(const std::vector<T>& arg, Buffer::Instance& dst); + + /** + * Encode given nullable array in a buffer + * @return bytes written + */ template <typename T> size_t encode(const NullableArray<T>& arg, Buffer::Instance& dst); + + int16_t apiVersion() const { return api_version_; } + +private: + const int16_t api_version_; };
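// NOTE (editorial sketch, not part of the original patch): the specializations below are
// selected by argument type; a version-aware context is created once per message and
// reused for every nested field, e.g.:
//
//   Buffer::OwnedImpl dst;
//   EncodingContext context{0};                 // api_version 0
//   const int16_t api_key = 8;
//   const NullableString client_id = absl::nullopt;
//   context.encode(api_key, dst);               // 2 bytes, big-endian
//   context.encode(client_id, dst);             // INT16 length of -1 marks a null string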
/** @@ -556,7 +667,7 @@ template <> inline size_t EncodingContext::encode(const int8_t& arg, Buffer::Ins */ #define ENCODE_NUMERIC_TYPE(TYPE, CONVERTER) \ template <> inline size_t EncodingContext::encode(const TYPE& arg, Buffer::Instance& dst) { \ - TYPE val = CONVERTER(arg); \ + const TYPE val = CONVERTER(arg); \ dst.add(&val, sizeof(TYPE)); \ return sizeof(TYPE); \ } @@ -596,7 +707,7 @@ inline size_t EncodingContext::encode(const NullableString& arg, Buffer::Instanc if (arg.has_value()) { return encode(*arg, dst); } else { - int16_t len = -1; + const int16_t len = -1; return encode(len, dst); } } @@ -606,8 +717,8 @@ * Encode byte array as INT32 length + N bytes */ template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Instance& dst) { - int32_t data_length = arg.size(); - size_t header_length = encode(data_length, dst); + const int32_t data_length = arg.size(); + const size_t header_length = encode(data_length, dst); dst.add(arg.data(), arg.size()); return header_length + data_length; } @@ -620,11 +731,21 @@ template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buff if (arg.has_value()) { return encode(*arg, dst); } else { - int32_t len = -1; + const int32_t len = -1; return encode(len, dst); } } +/** + * Encode object array of T as INT32 length + N elements + * Each element of type T then serializes itself on its own + */ +template <typename T> +size_t EncodingContext::encode(const std::vector<T>& arg, Buffer::Instance& dst) { + const NullableArray<T> wrapped = {arg}; + return encode(wrapped, dst); +} + /** * Encode nullable object array to T as INT32 length + N elements (length = -1 for null) * Each element of type T then serializes itself on its own */ template <typename T> size_t EncodingContext::encode(const NullableArray<T>& arg, Buffer::Instance& dst) { if (arg.has_value()) { - int32_t len = arg->size(); - size_t header_length = encode(len, dst); + const int32_t len = arg->size(); + const size_t header_length = encode(len, dst); size_t written{0}; for (const T& el : *arg) { // for each of array elements, resolve the correct method again written += encode(el, dst); } return header_length + written; } else { - int32_t len = -1; + const int32_t len = -1; return encode(len, dst); } } diff --git a/source/extensions/filters/network/kafka/generated/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h similarity index 80% rename from source/extensions/filters/network/kafka/generated/serialization_composite.h rename to source/extensions/filters/network/kafka/serialization_composite.h index 2a7b2ae261516..2c10291b594bd 100644 --- a/source/extensions/filters/network/kafka/generated/serialization_composite.h +++ b/source/extensions/filters/network/kafka/serialization_composite.h @@ -1,5 +1,5 @@ -// DO NOT EDIT - THIS FILE WAS GENERATED -// clang-format off +// FIXME(adam.kotwasinski) this file can be generated, as it's repeating the same code for 0..9 +// delegates #pragma once #include @@ -26,6 +26,7 @@ namespace Kafka { * This header contains only composite deserializers * The basic design is composite deserializer creating delegates DeserializerType1..N * Result of type ResponseType is constructed by getting results of each of delegates + * These deserializers can throw, if any of the delegate deserializers can */ /** @@ -38,26 +39,16 @@ namespace Kafka { * * @param ResponseType type of deserialized data */ -template < - typename ResponseType -> +template <typename ResponseType> class CompositeDeserializerWith0Delegates : public Deserializer<ResponseType> { public: - CompositeDeserializerWith0Delegates(){}; - size_t feed(const char*&, uint64_t&) { - return 0; - } + size_t feed(const char*&, uint64_t&) { return 0; } - bool ready() const { - return true; - } + bool ready() const { return true; } - ResponseType get() const { - return { - }; - } + ResponseType get() const { return {}; } protected: }; @@ -73,13 +64,9 @@ class CompositeDeserializerWith0Delegates : public Deserializer<ResponseType> { * @param ResponseType type of deserialized data * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1 -> +template <typename ResponseType, typename DeserializerType1> class CompositeDeserializerWith1Delegates : public Deserializer<ResponseType> { public: - CompositeDeserializerWith1Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ 
-88,15 +75,9 @@ class CompositeDeserializerWith1Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate1_.ready(); - } + bool ready() const { return delegate1_.ready(); } - ResponseType get() const { - return { - delegate1_.get() - }; - } + ResponseType get() const { return {delegate1_.get()}; } protected: DeserializerType1 delegate1_; @@ -114,14 +95,9 @@ class CompositeDeserializerWith1Delegates : public Deserializer { * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2 -> +template class CompositeDeserializerWith2Delegates : public Deserializer { public: - CompositeDeserializerWith2Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -131,16 +107,9 @@ class CompositeDeserializerWith2Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate2_.ready(); - } + bool ready() const { return delegate2_.ready(); } - ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get() - }; - } + ResponseType get() const { return {delegate1_.get(), delegate2_.get()}; } protected: DeserializerType1 delegate1_; @@ -160,15 +129,10 @@ class CompositeDeserializerWith2Delegates : public Deserializer { * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3 -> +template class CompositeDeserializerWith3Delegates : public Deserializer { public: - CompositeDeserializerWith3Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -179,17 +143,9 @@ class CompositeDeserializerWith3Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate3_.ready(); - } + bool ready() const { return delegate3_.ready(); } - ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get() - }; - } + ResponseType get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; } protected: DeserializerType1 delegate1_; @@ -211,16 +167,10 @@ class CompositeDeserializerWith3Delegates : public Deserializer { * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3, - typename DeserializerType4 -> +template class CompositeDeserializerWith4Delegates : public Deserializer { public: - CompositeDeserializerWith4Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -232,17 +182,10 @@ class CompositeDeserializerWith4Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate4_.ready(); - } + bool ready() const { return delegate4_.ready(); } ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get(), - delegate4_.get() - }; + return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get()}; } protected: @@ -267,17 +210,10 @@ class CompositeDeserializerWith4Delegates : public Deserializer { * @param 
DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3, - typename DeserializerType4, - typename DeserializerType5 -> +template class CompositeDeserializerWith5Delegates : public Deserializer { public: - CompositeDeserializerWith5Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -290,18 +226,11 @@ class CompositeDeserializerWith5Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate5_.ready(); - } + bool ready() const { return delegate5_.ready(); } ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get(), - delegate4_.get(), - delegate5_.get() - }; + return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), + delegate5_.get()}; } protected: @@ -328,18 +257,11 @@ class CompositeDeserializerWith5Delegates : public Deserializer { * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3, - typename DeserializerType4, - typename DeserializerType5, - typename DeserializerType6 -> +template class CompositeDeserializerWith6Delegates : public Deserializer { public: - CompositeDeserializerWith6Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -353,19 +275,11 @@ class CompositeDeserializerWith6Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate6_.ready(); - } + bool ready() const { return delegate6_.ready(); } ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get(), - delegate4_.get(), - delegate5_.get(), - delegate6_.get() - }; + return {delegate1_.get(), delegate2_.get(), delegate3_.get(), + delegate4_.get(), delegate5_.get(), delegate6_.get()}; } protected: @@ -394,19 +308,11 @@ class CompositeDeserializerWith6Delegates : public Deserializer { * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3, - typename DeserializerType4, - typename DeserializerType5, - typename DeserializerType6, - typename DeserializerType7 -> +template class CompositeDeserializerWith7Delegates : public Deserializer { public: - CompositeDeserializerWith7Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -421,20 +327,11 @@ class CompositeDeserializerWith7Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate7_.ready(); - } + bool ready() const { return delegate7_.ready(); } ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get(), - delegate4_.get(), - delegate5_.get(), - delegate6_.get(), - delegate7_.get() - }; + return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), + delegate5_.get(), delegate6_.get(), delegate7_.get()}; } protected: @@ -465,20 +362,11 @@ class CompositeDeserializerWith7Delegates : public 
Deserializer { * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) * @param DeserializerType8 deserializer 8 (result used as argument 8 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3, - typename DeserializerType4, - typename DeserializerType5, - typename DeserializerType6, - typename DeserializerType7, - typename DeserializerType8 -> +template class CompositeDeserializerWith8Delegates : public Deserializer { public: - CompositeDeserializerWith8Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -494,21 +382,11 @@ class CompositeDeserializerWith8Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate8_.ready(); - } + bool ready() const { return delegate8_.ready(); } ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get(), - delegate4_.get(), - delegate5_.get(), - delegate6_.get(), - delegate7_.get(), - delegate8_.get() - }; + return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), + delegate5_.get(), delegate6_.get(), delegate7_.get(), delegate8_.get()}; } protected: @@ -541,21 +419,12 @@ class CompositeDeserializerWith8Delegates : public Deserializer { * @param DeserializerType8 deserializer 8 (result used as argument 8 of ResponseType's ctor) * @param DeserializerType9 deserializer 9 (result used as argument 9 of ResponseType's ctor) */ -template < - typename ResponseType, - typename DeserializerType1, - typename DeserializerType2, - typename DeserializerType3, - typename DeserializerType4, - typename DeserializerType5, - typename DeserializerType6, - typename DeserializerType7, - typename DeserializerType8, - typename DeserializerType9 -> +template class CompositeDeserializerWith9Delegates : public Deserializer { public: - CompositeDeserializerWith9Delegates(){}; size_t feed(const char*& buffer, uint64_t& remaining) { @@ -572,22 +441,12 @@ class CompositeDeserializerWith9Delegates : public Deserializer { return consumed; } - bool ready() const { - return delegate9_.ready(); - } + bool ready() const { return delegate9_.ready(); } ResponseType get() const { - return { - delegate1_.get(), - delegate2_.get(), - delegate3_.get(), - delegate4_.get(), - delegate5_.get(), - delegate6_.get(), - delegate7_.get(), - delegate8_.get(), - delegate9_.get() - }; + return {delegate1_.get(), delegate2_.get(), delegate3_.get(), + delegate4_.get(), delegate5_.get(), delegate6_.get(), + delegate7_.get(), delegate8_.get(), delegate9_.get()}; } protected: @@ -606,4 +465,3 @@ class CompositeDeserializerWith9Delegates : public Deserializer { } // namespace NetworkFilters } // namespace Extensions } // namespace Envoy -// clang-format on diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index cc4ad6c12bcac..f767e0a2b24c8 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -23,7 +23,7 @@ envoy_extension_cc_test( envoy_extension_cc_test( name = "serialization_composite_test", - srcs = ["generated/serialization_composite_test.cc"], + srcs = ["serialization_composite_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ "//source/extensions/filters/network/kafka:serialization_lib", @@ -32,8 +32,8 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "kafka_request_test", - srcs = ["kafka_request_test.cc"], + name = 
"kafka_request_parser_test", + srcs = ["kafka_request_parser_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ "//source/extensions/filters/network/kafka:kafka_request_lib", @@ -53,10 +53,23 @@ envoy_extension_cc_test( envoy_extension_cc_test( name = "requests_test", - srcs = ["generated/requests_test.cc"], + srcs = ["requests_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ "//source/extensions/filters/network/kafka:kafka_request_codec_lib", "//test/mocks/server:server_mocks", ], ) + +genrule( + name = "kafka_generated_test", + srcs = [ + "@kafka_produce_request_spec//file", + "@kafka_fetch_request_spec//file", + ], + outs = ["requests_test.cc"], + cmd = "./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test $(location requests_test.cc) $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file)", + tools = [ + "//source/extensions/filters/network/kafka:kafka_code_generator", + ], +) diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc new file mode 100644 index 0000000000000..64e886c17b279 --- /dev/null +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -0,0 +1,249 @@ +#include "common/common/stack_array.h" + +#include "extensions/filters/network/kafka/kafka_request_parser.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::Return; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +class BufferBasedTest : public testing::Test { +public: + Buffer::OwnedImpl& buffer() { return buffer_; } + + const char* getBytes() { + uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + buffer_.getRawSlices(slices.begin(), num_slices); + return reinterpret_cast((slices[0]).mem_); + } + +protected: + Buffer::OwnedImpl buffer_; + EncodingContext encoder_{-1}; // api_version is not used for request header +}; + +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver(){}; + MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); +}; + +TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { + // given + MockRequestParserResolver resolver{}; + RequestStartParser testee{resolver}; + + int32_t request_len = 1234; + encoder_.encode(request_len, buffer()); + + const char* bytes = getBytes(); + uint64_t remaining = 1024; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); +} + +class MockParser : public Parser { +public: + ParseResponse parse(const char*&, uint64_t&) { + throw new EnvoyException("should not be invoked"); + } +}; + +TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) { + // given + const MockRequestParserResolver parser_resolver; + const ParserSharedPtr parser{new MockParser{}}; + EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser)); + + const int32_t request_len = 1000; + RequestContextSharedPtr context{new RequestContext()}; + context->remaining_request_size_ = request_len; + 
RequestHeaderParser testee{parser_resolver, context}; + + const int16_t api_key{1}; + const int16_t api_version{2}; + const int32_t correlation_id{10}; + const NullableString client_id{"aaa"}; + size_t written = 0; + written += encoder_.encode(api_key, buffer()); + written += encoder_.encode(api_version, buffer()); + written += encoder_.encode(correlation_id, buffer()); + written += encoder_.encode(client_id, buffer()); + + const char* bytes = getBytes(); + uint64_t remaining = 100000; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, parser); + ASSERT_EQ(result.message_, nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len - written); + ASSERT_EQ(remaining, orig_remaining - written); + + const RequestHeader expected_header{api_key, api_version, correlation_id, client_id}; + ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); +} + +TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDuringFeeding) { + // given + + // throws during feeding + class ThrowingRequestHeaderDeserializer : public RequestHeaderDeserializer { + public: + size_t feed(const char*&, uint64_t&) { throw EnvoyException("feed"); }; + + bool ready() const { throw std::runtime_error("should not be invoked at all"); }; + + RequestHeader get() const { throw std::runtime_error("should not be invoked at all"); }; + }; + + const MockRequestParserResolver parser_resolver; + + const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request + RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; + RequestHeaderParser testee{parser_resolver, request_context, + std::make_unique()}; + + const char* bytes = getBytes(); + const char* orig_bytes = bytes; + uint64_t remaining = 100000; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + + ASSERT_EQ(bytes, orig_bytes + request_size); + ASSERT_EQ(remaining, orig_remaining - request_size); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); +} + +TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { + // given + + // throws during feeding + class ThrowingDeserializer : public Deserializer { + public: + size_t feed(const char*&, uint64_t&) { throw EnvoyException("feed"); }; + + bool ready() const { throw std::runtime_error("should not be invoked at all"); }; + + int32_t get() const { throw std::runtime_error("should not be invoked at all"); }; + }; + + const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request + RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; + + RequestParser testee{request_context}; + + const char* bytes = getBytes(); + const char* orig_bytes = bytes; + uint64_t remaining = 100000; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + + ASSERT_EQ(bytes, orig_bytes + request_size); + ASSERT_EQ(remaining, orig_remaining - 
request_size); +} + +// deserializer that consumes 4 bytes and returns 0 +class FourBytesDeserializer : public Deserializer { +public: + size_t feed(const char*& buffer, uint64_t& remaining) { + buffer += 4; + remaining -= 4; + return 4; + }; + + bool ready() const { return true; }; + + int32_t get() const { return 0; }; +}; + +TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyButLeavingData) { + // given + const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request + RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; + + RequestParser testee{request_context}; + + const char* bytes = getBytes(); + const char* orig_bytes = bytes; + uint64_t remaining = 100000; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + + ASSERT_EQ(bytes, orig_bytes + request_size); + ASSERT_EQ(remaining, orig_remaining - request_size); +} + +TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { + // given + const int32_t request_len = 1000; + RequestContextSharedPtr context{new RequestContext()}; + context->remaining_request_size_ = request_len; + SentinelParser testee{context}; + + const Bytes garbage(request_len * 2); + encoder_.encode(garbage, buffer()); + + const char* bytes = getBytes(); + uint64_t remaining = request_len * 2; + const uint64_t orig_remaining = remaining; + + // when + const ParseResponse result = testee.parse(bytes, remaining); + + // then + ASSERT_EQ(result.hasData(), true); + ASSERT_EQ(result.next_parser_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); + ASSERT_EQ(remaining, orig_remaining - request_len); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/kafka_request_test.cc b/test/extensions/filters/network/kafka/kafka_request_test.cc deleted file mode 100644 index a78fd835b30d2..0000000000000 --- a/test/extensions/filters/network/kafka/kafka_request_test.cc +++ /dev/null @@ -1,131 +0,0 @@ -#include "common/common/stack_array.h" - -#include "extensions/filters/network/kafka/generated/requests.h" -#include "extensions/filters/network/kafka/kafka_request.h" - -#include "test/mocks/server/mocks.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::_; -using testing::Return; - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -class BufferBasedTest : public testing::Test { -public: - Buffer::OwnedImpl& buffer() { return buffer_; } - - const char* getBytes() { - uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); - STACK_ARRAY(slices, Buffer::RawSlice, num_slices); - buffer_.getRawSlices(slices.begin(), num_slices); - return reinterpret_cast((slices[0]).mem_); - } - -protected: - Buffer::OwnedImpl buffer_; - EncodingContext encoder_; -}; - -class MockRequestParserResolver : public RequestParserResolver { -public: - MockRequestParserResolver(){}; - MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); -}; - -TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { - // given - MockRequestParserResolver resolver{}; - 
RequestStartParser testee{resolver}; - - int32_t request_len = 1234; - encoder_.encode(request_len, buffer()); - - const char* bytes = getBytes(); - uint64_t remaining = 1024; - - // when - const ParseResponse result = testee.parse(bytes, remaining); - - // then - ASSERT_EQ(result.hasData(), true); - ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); - ASSERT_EQ(result.message_, nullptr); - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); -} - -TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) { - // given - const MockRequestParserResolver parser_resolver; - const ParserSharedPtr parser{new OffsetCommitRequestV0Parser{nullptr}}; - EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser)); - - const int32_t request_len = 1000; - RequestContextSharedPtr context{new RequestContext()}; - context->remaining_request_size_ = request_len; - RequestHeaderParser testee{parser_resolver, context}; - - const int16_t api_key{1}; - const int16_t api_version{2}; - const int32_t correlation_id{10}; - const NullableString client_id{"aaa"}; - size_t written = 0; - written += encoder_.encode(api_key, buffer()); - written += encoder_.encode(api_version, buffer()); - written += encoder_.encode(correlation_id, buffer()); - written += encoder_.encode(client_id, buffer()); - - const char* bytes = getBytes(); - uint64_t remaining = 100000; - const uint64_t orig_remaining = remaining; - - // when - const ParseResponse result = testee.parse(bytes, remaining); - - // then - ASSERT_EQ(result.hasData(), true); - ASSERT_EQ(result.next_parser_, parser); - ASSERT_EQ(result.message_, nullptr); - - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len - written); - ASSERT_EQ(remaining, orig_remaining - written); - - const RequestHeader expected_header{api_key, api_version, correlation_id, client_id}; - ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); -} - -TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { - // given - const int32_t request_len = 1000; - RequestContextSharedPtr context{new RequestContext()}; - context->remaining_request_size_ = request_len; - SentinelParser testee{context}; - - const Bytes garbage(request_len * 2); - encoder_.encode(garbage, buffer()); - - const char* bytes = getBytes(); - uint64_t remaining = request_len * 2; - const uint64_t orig_remaining = remaining; - - // when - const ParseResponse result = testee.parse(bytes, remaining); - - // then - ASSERT_EQ(result.hasData(), true); - ASSERT_EQ(result.next_parser_, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); - - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); - ASSERT_EQ(remaining, orig_remaining - request_len); -} - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc index 0c408c9d32a0a..d273b6b1b20ab 100644 --- a/test/extensions/filters/network/kafka/request_codec_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_test.cc @@ -1,4 +1,3 @@ -#include "extensions/filters/network/kafka/generated/requests.h" #include "extensions/filters/network/kafka/request_codec.h" #include "test/mocks/server/mocks.h" @@ -32,7 +31,7 @@ class MockRequestParserResolver : public RequestParserResolver { }; template std::shared_ptr 
RequestDecoderTest::serializeAndDeserialize(T request) { - RequestEncoder serializer{buffer_}; + MessageEncoderImpl serializer{buffer_}; serializer.encode(request); std::shared_ptr mock_listener = std::make_shared(); @@ -51,12 +50,28 @@ ParserSharedPtr createSentinelParser(testing::Unused, testing::Unused, return std::make_shared(context); } +struct MockRequest { + const int32_t field1_ = 1; + const int64_t field2_ = 2; + + size_t encode(Buffer::Instance& buffer, EncodingContext& encoder) const { + size_t written{0}; + written += encoder.encode(field1_, buffer); + written += encoder.encode(field2_, buffer); + return written; + } + + friend std::ostream& operator<<(std::ostream& os, const MockRequest&) { + return os << "{MockRequest}"; + }; +}; + TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given - RequestEncoder serializer{buffer_}; - NullableArray topics{{{"topic1", {{{{0, 10, "m1"}}}}}}}; - OffsetCommitRequestV0 request{"group_id", topics}; - request.setMetadata(42, "client-id"); + MessageEncoderImpl serializer{buffer_}; + MockRequest data{}; + // api key & version values do not matter, as resolver recognizes nothing + ConcreteRequest request = {{1000, 2000, 3000, "correlation-id"}, data}; serializer.encode(request); diff --git a/test/extensions/filters/network/kafka/generated/serialization_composite_test.cc b/test/extensions/filters/network/kafka/serialization_composite_test.cc similarity index 83% rename from test/extensions/filters/network/kafka/generated/serialization_composite_test.cc rename to test/extensions/filters/network/kafka/serialization_composite_test.cc index d310bf4530e15..d406ae214b63d 100644 --- a/test/extensions/filters/network/kafka/generated/serialization_composite_test.cc +++ b/test/extensions/filters/network/kafka/serialization_composite_test.cc @@ -1,9 +1,9 @@ -// DO NOT EDIT - THIS FILE WAS GENERATED -// clang-format off +// FIXME(adam.kotwasinski) this file can be generated, as it's repeating the same code for 0..9 +// delegates #include "common/common/stack_array.h" -#include "extensions/filters/network/kafka/generated/serialization_composite.h" #include "extensions/filters/network/kafka/serialization.h" +#include "extensions/filters/network/kafka/serialization_composite.h" #include "test/mocks/server/mocks.h" @@ -41,7 +41,7 @@ void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { BT testee{}; Buffer::OwnedImpl buffer; - EncodingContext encoder; + EncodingContext encoder{-1}; const size_t written = encoder.encode(expected, buffer); uint64_t remaining = @@ -79,7 +79,7 @@ void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { BT testee{}; Buffer::OwnedImpl buffer; - EncodingContext encoder; + EncodingContext encoder{-1}; const size_t written = encoder.encode(expected, buffer); const char* data = getRawData(buffer); @@ -119,17 +119,12 @@ template void serializeThenDeserializeAndCheckEqualit struct CompositeResultWith0Fields { - size_t encode(Buffer::Instance&, EncodingContext&) const { - return 0; - } + size_t encode(Buffer::Instance&, EncodingContext&) const { return 0; } - bool operator==(const CompositeResultWith0Fields&) const { - return true; - } + bool operator==(const CompositeResultWith0Fields&) const { return true; } }; -typedef CompositeDeserializerWith0Delegates - TestCompositeDeserializer0; +typedef CompositeDeserializerWith0Delegates TestCompositeDeserializer0; // composite with 0 delegates is special case: it's always ready TEST(CompositeDeserializerWith0Delegates, EmptyBufferShouldBeReady) { 
@@ -153,9 +148,7 @@ struct CompositeResultWith1Fields { return written; } - bool operator==(const CompositeResultWith1Fields& rhs) const { - return field1_ == rhs.field1_; - } + bool operator==(const CompositeResultWith1Fields& rhs) const { return field1_ == rhs.field1_; } }; typedef CompositeDeserializerWith1Delegates @@ -189,7 +182,8 @@ struct CompositeResultWith2Fields { } }; -typedef CompositeDeserializerWith2Delegates +typedef CompositeDeserializerWith2Delegates TestCompositeDeserializer2; TEST(CompositeDeserializerWith2Delegates, EmptyBufferShouldNotBeReady) { @@ -222,7 +216,8 @@ struct CompositeResultWith3Fields { } }; -typedef CompositeDeserializerWith3Delegates +typedef CompositeDeserializerWith3Delegates TestCompositeDeserializer3; TEST(CompositeDeserializerWith3Delegates, EmptyBufferShouldNotBeReady) { @@ -253,11 +248,14 @@ struct CompositeResultWith4Fields { } bool operator==(const CompositeResultWith4Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_; + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && + field4_ == rhs.field4_; } }; -typedef CompositeDeserializerWith4Delegates +typedef CompositeDeserializerWith4Delegates TestCompositeDeserializer4; TEST(CompositeDeserializerWith4Delegates, EmptyBufferShouldNotBeReady) { @@ -290,11 +288,14 @@ struct CompositeResultWith5Fields { } bool operator==(const CompositeResultWith5Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_; + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && + field4_ == rhs.field4_ && field5_ == rhs.field5_; } }; -typedef CompositeDeserializerWith5Delegates +typedef CompositeDeserializerWith5Delegates TestCompositeDeserializer5; TEST(CompositeDeserializerWith5Delegates, EmptyBufferShouldNotBeReady) { @@ -329,11 +330,14 @@ struct CompositeResultWith6Fields { } bool operator==(const CompositeResultWith6Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_; + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && + field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_; } }; -typedef CompositeDeserializerWith6Delegates +typedef CompositeDeserializerWith6Delegates< + CompositeResultWith6Fields, StringDeserializer, StringDeserializer, StringDeserializer, + StringDeserializer, StringDeserializer, StringDeserializer> TestCompositeDeserializer6; TEST(CompositeDeserializerWith6Delegates, EmptyBufferShouldNotBeReady) { @@ -370,11 +374,15 @@ struct CompositeResultWith7Fields { } bool operator==(const CompositeResultWith7Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && field7_ == rhs.field7_; + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && + field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && + field7_ == rhs.field7_; } }; -typedef CompositeDeserializerWith7Delegates +typedef CompositeDeserializerWith7Delegates< + CompositeResultWith7Fields, StringDeserializer, StringDeserializer, StringDeserializer, + StringDeserializer, StringDeserializer, StringDeserializer, StringDeserializer> 
TestCompositeDeserializer7; TEST(CompositeDeserializerWith7Delegates, EmptyBufferShouldNotBeReady) { @@ -413,11 +421,16 @@ struct CompositeResultWith8Fields { } bool operator==(const CompositeResultWith8Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && field7_ == rhs.field7_ && field8_ == rhs.field8_; + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && + field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && + field7_ == rhs.field7_ && field8_ == rhs.field8_; } }; -typedef CompositeDeserializerWith8Delegates +typedef CompositeDeserializerWith8Delegates< + CompositeResultWith8Fields, StringDeserializer, StringDeserializer, StringDeserializer, + StringDeserializer, StringDeserializer, StringDeserializer, StringDeserializer, + StringDeserializer> TestCompositeDeserializer8; TEST(CompositeDeserializerWith8Delegates, EmptyBufferShouldNotBeReady) { @@ -458,11 +471,16 @@ struct CompositeResultWith9Fields { } bool operator==(const CompositeResultWith9Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && field7_ == rhs.field7_ && field8_ == rhs.field8_ && field9_ == rhs.field9_; + return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && + field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && + field7_ == rhs.field7_ && field8_ == rhs.field8_ && field9_ == rhs.field9_; } }; -typedef CompositeDeserializerWith9Delegates +typedef CompositeDeserializerWith9Delegates< + CompositeResultWith9Fields, StringDeserializer, StringDeserializer, StringDeserializer, + StringDeserializer, StringDeserializer, StringDeserializer, StringDeserializer, + StringDeserializer, StringDeserializer> TestCompositeDeserializer9; TEST(CompositeDeserializerWith9Delegates, EmptyBufferShouldNotBeReady) { @@ -481,4 +499,3 @@ TEST(CompositeDeserializerWith9Delegates, ShouldDeserialize) { } // namespace NetworkFilters } // namespace Extensions } // namespace Envoy -// clang-format on diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index a684e7bc649f2..05c67fe490b29 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -1,7 +1,7 @@ #include "common/common/stack_array.h" -#include "extensions/filters/network/kafka/generated/serialization_composite.h" #include "extensions/filters/network/kafka/serialization.h" +#include "extensions/filters/network/kafka/serialization_composite.h" #include "test/mocks/server/mocks.h" @@ -44,7 +44,14 @@ TEST(ArrayDeserializer, EmptyBufferShouldNotBeReady) { ASSERT_EQ(testee.ready(), false); } -EncodingContext encoder; +TEST(NullableArrayDeserializer, EmptyBufferShouldNotBeReady) { + // given + const NullableArrayDeserializer testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +EncodingContext encoder{-1}; // api_version does not matter for primitive types // helper function const char* getRawData(const Buffer::OwnedImpl& buffer) { @@ -272,7 +279,7 @@ TEST(NullableBytesDeserializer, ShouldThrowOnInvalidLength) { } TEST(ArrayDeserializer, ShouldConsumeCorrectAmountOfData) { - const NullableArray value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}}; + 
const std::vector value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}}; serializeThenDeserializeAndCheckEquality>( value); } @@ -282,6 +289,28 @@ TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) { ArrayDeserializer testee; Buffer::OwnedImpl buffer; + const int32_t len = -1; // ARRAY accepts only >= 0 + encoder.encode(len, buffer); + + uint64_t remaining = 1024; + const char* data = getRawData(buffer); + + // when + // then + EXPECT_THROW(testee.feed(data, remaining), EnvoyException); +} + +TEST(NullableArrayDeserializer, ShouldConsumeCorrectAmountOfData) { + const NullableArray value{{"aaa", "bbbbb", "cc", "d", "e", "ffffffff"}}; + serializeThenDeserializeAndCheckEquality< + NullableArrayDeserializer>(value); +} + +TEST(NullableArrayDeserializer, ShouldThrowOnInvalidLength) { + // given + NullableArrayDeserializer testee; + Buffer::OwnedImpl buffer; + const int32_t len = -2; // -1 is OK for ARRAY encoder.encode(len, buffer); From 00aac5b4861657375ddb28f4c15c700cb5aa21bc Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Wed, 20 Feb 2019 17:45:22 -0800 Subject: [PATCH 12/29] In case of parse errors, consume rest of request properly; apply clang-tidy fixes Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/BUILD | 3 + .../network/kafka/kafka_request_parser.cc | 14 +-- .../network/kafka/kafka_request_parser.h | 63 ++++++-------- .../filters/network/kafka/serialization.h | 60 ++++++------- .../network/kafka/serialization_composite.h | 62 +++++++------- .../kafka/kafka_request_parser_test.cc | 85 ++++++++++++------- 6 files changed, 155 insertions(+), 132 deletions(-) diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index 7a6387ff3a0b8..c86d91c01be67 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -83,6 +83,9 @@ envoy_cc_library( hdrs = [ "message.h", ], + deps = [ + "//include/envoy/buffer:buffer_interface", + ], ) envoy_cc_library( diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc index 5cadeccb069b7..4923d3d2948e4 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.cc +++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc @@ -17,16 +17,16 @@ ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining } ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remaining) { + const uint64_t orig_remaining = remaining; try { context_->remaining_request_size_ -= deserializer_->feed(buffer, remaining); } catch (const EnvoyException& e) { - buffer += context_->remaining_request_size_; - remaining -= context_->remaining_request_size_; - context_->remaining_request_size_ = 0; - - const RequestHeader header{-1, -1, -1, absl::nullopt}; - return ParseResponse::parsedMessage( - std::make_shared(context_->request_header_)); + // unable to compute request header, but we still need to consume rest of request (some of the + // data might have been consumed) + const int32_t consumed = static_cast(orig_remaining - remaining); + context_->remaining_request_size_ -= consumed; + context_->request_header_ = {-1, -1, -1, absl::nullopt}; + return ParseResponse::nextParser(std::make_shared(context_)); } if (deserializer_->ready()) { diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h index 
0580e705784bb..76b60d5a7b6b9 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.h +++ b/source/extensions/filters/network/kafka/kafka_request_parser.h @@ -116,6 +116,26 @@ class RequestHeaderParser : public Parser { RequestHeaderDeserializerPtr deserializer_; }; +/** + * Sentinel parser that is responsible for consuming message bytes for messages that had unsupported + * api_key & api_version. It does not attempt to capture any data, just throws it away until end of + * message + */ +class SentinelParser : public Parser { +public: + SentinelParser(RequestContextSharedPtr context) : context_{context} {}; + + /** + * Returns UnknownRequest + */ + ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + + const RequestContextSharedPtr contextForTest() const { return context_; } + +private: + const RequestContextSharedPtr context_; +}; + /** * Request parser uses a single deserializer to construct a request object * This parser is responsible for consuming request-specific data (e.g. topic names) and always @@ -137,13 +157,15 @@ template class RequestParser : * Fill in request's header with data stored in context */ ParseResponse parse(const char*& buffer, uint64_t& remaining) override { + const uint64_t orig_remaining = remaining; try { context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); } catch (const EnvoyException&) { // treat the whole request as invalid, throw away the rest of the data - ignoreRestOfRequest(buffer, remaining); - return ParseResponse::parsedMessage( - std::make_shared(context_->request_header_)); + const int32_t consumed = static_cast(orig_remaining - remaining); + context_->remaining_request_size_ -= + consumed; // some of the data might have been consumed by throwing deserializer + return ParseResponse::nextParser(std::make_shared(context_)); } if (deserializer.ready()) { @@ -155,47 +177,18 @@ template class RequestParser : } else { // the message makes no sense, the deserializer that matches the schema consumed all // necessary data, but there's still unconsumed bytes - ignoreRestOfRequest(buffer, remaining); - return ParseResponse::parsedMessage( - std::make_shared(context_->request_header_)); + return ParseResponse::nextParser(std::make_shared(context_)); } } else { return ParseResponse::stillWaiting(); } } + const RequestContextSharedPtr contextForTest() const { return context_; } + protected: RequestContextSharedPtr context_; DeserializerType deserializer; // underlying request-specific deserializer - -private: - // moves the pointers until the end of request, so that the data that does not match the - // deserializers gets ignored - void ignoreRestOfRequest(const char*& buffer, uint64_t& remaining) { - buffer += context_->remaining_request_size_; - remaining -= context_->remaining_request_size_; - context_->remaining_request_size_ = 0; - } -}; - -/** - * Sentinel parser that is responsible for consuming message bytes for messages that had unsupported - * api_key & api_version. 
It does not attempt to capture any data, just throws it away until end of - * message - */ -class SentinelParser : public Parser { -public: - SentinelParser(RequestContextSharedPtr context) : context_{context} {}; - - /** - * Returns UnknownRequest - */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; - - const RequestContextSharedPtr contextForTest() const { return context_; } - -private: - const RequestContextSharedPtr context_; }; } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index 1b62c516f3574..438ad84122e72 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -60,7 +60,7 @@ template class IntDeserializer : public Deserializer { public: IntDeserializer() : written_{0}, ready_(false){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { const size_t available = std::min(sizeof(buf_) - written_, remaining); memcpy(buf_ + written_, buffer, available); written_ += available; @@ -75,12 +75,12 @@ template class IntDeserializer : public Deserializer { return available; } - bool ready() const { return ready_; } + bool ready() const override { return ready_; } protected: char buf_[sizeof(T) / sizeof(char)]; size_t written_; - bool ready_; + bool ready_{false}; }; /** @@ -88,7 +88,7 @@ template class IntDeserializer : public Deserializer { */ class Int8Deserializer : public IntDeserializer { public: - int8_t get() const { + int8_t get() const override { int8_t result; memcpy(&result, buf_, sizeof(result)); return result; @@ -100,7 +100,7 @@ class Int8Deserializer : public IntDeserializer { */ class Int16Deserializer : public IntDeserializer { public: - int16_t get() const { + int16_t get() const override { int16_t result; memcpy(&result, buf_, sizeof(result)); return be16toh(result); @@ -112,7 +112,7 @@ class Int16Deserializer : public IntDeserializer { */ class Int32Deserializer : public IntDeserializer { public: - int32_t get() const { + int32_t get() const override { int32_t result; memcpy(&result, buf_, sizeof(result)); return be32toh(result); @@ -124,7 +124,7 @@ class Int32Deserializer : public IntDeserializer { */ class UInt32Deserializer : public IntDeserializer { public: - uint32_t get() const { + uint32_t get() const override { uint32_t result; memcpy(&result, buf_, sizeof(result)); return be32toh(result); @@ -136,7 +136,7 @@ class UInt32Deserializer : public IntDeserializer { */ class Int64Deserializer : public IntDeserializer { public: - int64_t get() const { + int64_t get() const override { int64_t result; memcpy(&result, buf_, sizeof(result)); return be64toh(result); @@ -157,11 +157,13 @@ class BooleanDeserializer : public Deserializer { public: BooleanDeserializer(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { return buffer_.feed(buffer, remaining); } + size_t feed(const char*& buffer, uint64_t& remaining) override { + return buffer_.feed(buffer, remaining); + } - bool ready() const { return buffer_.ready(); } + bool ready() const override { return buffer_.ready(); } - bool get() const { return 0 != buffer_.get(); } + bool get() const override { return 0 != buffer_.get(); } private: Int8Deserializer buffer_; @@ -181,7 +183,7 @@ class StringDeserializer : public Deserializer { /** * Can throw EnvoyException if given string length is not valid */ - size_t feed(const char*& buffer, uint64_t& 
remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { // break early: we still need to fill in length buffer @@ -213,9 +215,9 @@ class StringDeserializer : public Deserializer { return length_consumed + data_consumed; } - bool ready() const { return ready_; } + bool ready() const override { return ready_; } - std::string get() const { return std::string(data_buf_.begin(), data_buf_.end()); } + std::string get() const override { return std::string(data_buf_.begin(), data_buf_.end()); } private: Int16Deserializer length_buf_; @@ -243,7 +245,7 @@ class NullableStringDeserializer : public Deserializer { /** * Can throw EnvoyException if given string length is not valid */ - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { // break early: we still need to fill in length buffer @@ -285,9 +287,9 @@ class NullableStringDeserializer : public Deserializer { return length_consumed + data_consumed; } - bool ready() const { return ready_; } + bool ready() const override { return ready_; } - NullableString get() const { + NullableString get() const override { return required_ >= 0 ? absl::make_optional(std::string(data_buf_.begin(), data_buf_.end())) : absl::nullopt; } @@ -316,7 +318,7 @@ class BytesDeserializer : public Deserializer { /** * Can throw EnvoyException if given bytes length is not valid */ - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { // break early: we still need to fill in length buffer @@ -348,9 +350,9 @@ class BytesDeserializer : public Deserializer { return length_consumed + data_consumed; } - bool ready() const { return ready_; } + bool ready() const override { return ready_; } - Bytes get() const { return data_buf_; } + Bytes get() const override { return data_buf_; } private: Int32Deserializer length_buf_; @@ -376,7 +378,7 @@ class NullableBytesDeserializer : public Deserializer { /** * Can throw EnvoyException if given bytes length is not valid */ - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { // break early: we still need to fill in length buffer @@ -418,9 +420,9 @@ class NullableBytesDeserializer : public Deserializer { return length_consumed + data_consumed; } - bool ready() const { return ready_; } + bool ready() const override { return ready_; } - NullableBytes get() const { + NullableBytes get() const override { if (NULL_BYTES_LENGTH == required_) { return absl::nullopt; } else { @@ -459,7 +461,7 @@ class ArrayDeserializer : public Deserializer> { /** * Can throw EnvoyException if array length is invalid or if DeserializerType can throw */ - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { const size_t length_consumed = length_buf_.feed(buffer, remaining); if (!length_buf_.ready()) { @@ -495,9 +497,9 @@ class ArrayDeserializer : public Deserializer> { return length_consumed + child_consumed; } - bool ready() const { return ready_; } + bool ready() const override { return 
ready_; } - std::vector get() const { + std::vector get() const override { std::vector result{}; result.reserve(children_.size()); for (const DeserializerType& child : children_) { @@ -536,7 +538,7 @@ class NullableArrayDeserializer : public Deserializer get() const { + NullableArray get() const override { if (NULL_ARRAY_LENGTH != required_) { std::vector result{}; result.reserve(children_.size()); diff --git a/source/extensions/filters/network/kafka/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h index 2c10291b594bd..9dd55b27ea4d6 100644 --- a/source/extensions/filters/network/kafka/serialization_composite.h +++ b/source/extensions/filters/network/kafka/serialization_composite.h @@ -44,11 +44,11 @@ class CompositeDeserializerWith0Delegates : public Deserializer { public: CompositeDeserializerWith0Delegates(){}; - size_t feed(const char*&, uint64_t&) { return 0; } + size_t feed(const char*&, uint64_t&) override { return 0; } - bool ready() const { return true; } + bool ready() const override { return true; } - ResponseType get() const { return {}; } + ResponseType get() const override { return {}; } protected: }; @@ -69,15 +69,15 @@ class CompositeDeserializerWith1Delegates : public Deserializer { public: CompositeDeserializerWith1Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); return consumed; } - bool ready() const { return delegate1_.ready(); } + bool ready() const override { return delegate1_.ready(); } - ResponseType get() const { return {delegate1_.get()}; } + ResponseType get() const override { return {delegate1_.get()}; } protected: DeserializerType1 delegate1_; @@ -100,16 +100,16 @@ class CompositeDeserializerWith2Delegates : public Deserializer { public: CompositeDeserializerWith2Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); return consumed; } - bool ready() const { return delegate2_.ready(); } + bool ready() const override { return delegate2_.ready(); } - ResponseType get() const { return {delegate1_.get(), delegate2_.get()}; } + ResponseType get() const override { return {delegate1_.get(), delegate2_.get()}; } protected: DeserializerType1 delegate1_; @@ -135,7 +135,7 @@ class CompositeDeserializerWith3Delegates : public Deserializer { public: CompositeDeserializerWith3Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -143,9 +143,11 @@ class CompositeDeserializerWith3Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate3_.ready(); } + bool ready() const override { return delegate3_.ready(); } - ResponseType get() const { return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; } + ResponseType get() const override { + return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; + } protected: DeserializerType1 delegate1_; @@ -173,7 +175,7 @@ class CompositeDeserializerWith4Delegates : public Deserializer { public: CompositeDeserializerWith4Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + 
size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -182,9 +184,9 @@ class CompositeDeserializerWith4Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate4_.ready(); } + bool ready() const override { return delegate4_.ready(); } - ResponseType get() const { + ResponseType get() const override { return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get()}; } @@ -216,7 +218,7 @@ class CompositeDeserializerWith5Delegates : public Deserializer { public: CompositeDeserializerWith5Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -226,9 +228,9 @@ class CompositeDeserializerWith5Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate5_.ready(); } + bool ready() const override { return delegate5_.ready(); } - ResponseType get() const { + ResponseType get() const override { return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), delegate5_.get()}; } @@ -264,7 +266,7 @@ class CompositeDeserializerWith6Delegates : public Deserializer { public: CompositeDeserializerWith6Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -275,9 +277,9 @@ class CompositeDeserializerWith6Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate6_.ready(); } + bool ready() const override { return delegate6_.ready(); } - ResponseType get() const { + ResponseType get() const override { return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), delegate5_.get(), delegate6_.get()}; } @@ -315,7 +317,7 @@ class CompositeDeserializerWith7Delegates : public Deserializer { public: CompositeDeserializerWith7Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -327,9 +329,9 @@ class CompositeDeserializerWith7Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate7_.ready(); } + bool ready() const override { return delegate7_.ready(); } - ResponseType get() const { + ResponseType get() const override { return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), delegate5_.get(), delegate6_.get(), delegate7_.get()}; } @@ -369,7 +371,7 @@ class CompositeDeserializerWith8Delegates : public Deserializer { public: CompositeDeserializerWith8Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -382,9 +384,9 @@ class CompositeDeserializerWith8Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate8_.ready(); } + bool ready() const override { return delegate8_.ready(); } - ResponseType get() const { + ResponseType get() const override { 
return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), delegate5_.get(), delegate6_.get(), delegate7_.get(), delegate8_.get()}; } @@ -427,7 +429,7 @@ class CompositeDeserializerWith9Delegates : public Deserializer { public: CompositeDeserializerWith9Delegates(){}; - size_t feed(const char*& buffer, uint64_t& remaining) { + size_t feed(const char*& buffer, uint64_t& remaining) override { size_t consumed = 0; consumed += delegate1_.feed(buffer, remaining); consumed += delegate2_.feed(buffer, remaining); @@ -441,9 +443,9 @@ class CompositeDeserializerWith9Delegates : public Deserializer { return consumed; } - bool ready() const { return delegate9_.ready(); } + bool ready() const override { return delegate9_.ready(); } - ResponseType get() const { + ResponseType get() const override { return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), delegate5_.get(), delegate6_.get(), delegate7_.get(), delegate8_.get(), delegate9_.get()}; diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index 64e886c17b279..f94e0ac3cbb45 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -15,6 +15,8 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { +const int32_t FAILED_DESERIALIZER_STEP = 13; + class BufferBasedTest : public testing::Test { public: Buffer::OwnedImpl& buffer() { return buffer_; } @@ -60,7 +62,7 @@ TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { class MockParser : public Parser { public: - ParseResponse parse(const char*&, uint64_t&) { + ParseResponse parse(const char*&, uint64_t&) override { throw new EnvoyException("should not be invoked"); } }; @@ -111,11 +113,18 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur // throws during feeding class ThrowingRequestHeaderDeserializer : public RequestHeaderDeserializer { public: - size_t feed(const char*&, uint64_t&) { throw EnvoyException("feed"); }; - - bool ready() const { throw std::runtime_error("should not be invoked at all"); }; - - RequestHeader get() const { throw std::runtime_error("should not be invoked at all"); }; + size_t feed(const char*& buffer, uint64_t& remaining) override { + // move some pointers to simulate data consumption + buffer += FAILED_DESERIALIZER_STEP; + remaining -= FAILED_DESERIALIZER_STEP; + throw EnvoyException("feed"); + }; + + bool ready() const override { throw std::runtime_error("should not be invoked at all"); }; + + RequestHeader get() const override { + throw std::runtime_error("should not be invoked at all"); + }; }; const MockRequestParserResolver parser_resolver; @@ -135,26 +144,34 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur // then ASSERT_EQ(result.hasData(), true); - ASSERT_EQ(result.next_parser_, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); - ASSERT_EQ(bytes, orig_bytes + request_size); - ASSERT_EQ(remaining, orig_remaining - request_size); + ASSERT_EQ(bytes, orig_bytes + FAILED_DESERIALIZER_STEP); + ASSERT_EQ(remaining, orig_remaining - FAILED_DESERIALIZER_STEP); - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, + 
request_size - FAILED_DESERIALIZER_STEP); } TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { // given + const int32_t move = FAILED_DESERIALIZER_STEP; + // throws during feeding class ThrowingDeserializer : public Deserializer { public: - size_t feed(const char*&, uint64_t&) { throw EnvoyException("feed"); }; + size_t feed(const char*& buffer, uint64_t& remaining) override { + // move some pointers to simulate data consumption + buffer += move; + remaining -= move; + throw EnvoyException("feed"); + }; - bool ready() const { throw std::runtime_error("should not be invoked at all"); }; + bool ready() const override { throw std::runtime_error("should not be invoked at all"); }; - int32_t get() const { throw std::runtime_error("should not be invoked at all"); }; + int32_t get() const override { throw std::runtime_error("should not be invoked at all"); }; }; const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request @@ -172,25 +189,28 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFee // then ASSERT_EQ(result.hasData(), true); - ASSERT_EQ(result.next_parser_, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + + ASSERT_EQ(bytes, orig_bytes + FAILED_DESERIALIZER_STEP); + ASSERT_EQ(remaining, orig_remaining - FAILED_DESERIALIZER_STEP); - ASSERT_EQ(bytes, orig_bytes + request_size); - ASSERT_EQ(remaining, orig_remaining - request_size); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, + request_size - FAILED_DESERIALIZER_STEP); } -// deserializer that consumes 4 bytes and returns 0 -class FourBytesDeserializer : public Deserializer { +// deserializer that consumes FAILED_DESERIALIZER_STEP bytes and returns 0 +class SomeBytesDeserializer : public Deserializer { public: - size_t feed(const char*& buffer, uint64_t& remaining) { - buffer += 4; - remaining -= 4; - return 4; + size_t feed(const char*& buffer, uint64_t& remaining) override { + buffer += FAILED_DESERIALIZER_STEP; + remaining -= FAILED_DESERIALIZER_STEP; + return FAILED_DESERIALIZER_STEP; }; - bool ready() const { return true; }; + bool ready() const override { return true; }; - int32_t get() const { return 0; }; + int32_t get() const override { return 0; }; }; TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyButLeavingData) { @@ -198,7 +218,7 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyBut const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; - RequestParser testee{request_context}; + RequestParser testee{request_context}; const char* bytes = getBytes(); const char* orig_bytes = bytes; @@ -210,11 +230,14 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyBut // then ASSERT_EQ(result.hasData(), true); - ASSERT_EQ(result.next_parser_, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); + ASSERT_EQ(result.message_, nullptr); + + ASSERT_EQ(bytes, orig_bytes + FAILED_DESERIALIZER_STEP); + ASSERT_EQ(remaining, orig_remaining - FAILED_DESERIALIZER_STEP); - ASSERT_EQ(bytes, orig_bytes + request_size); - ASSERT_EQ(remaining, orig_remaining - request_size); + 
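The surrounding test pins down a specific branch of RequestParser::parse: a deserializer may claim ready() while request bytes are still pending, and the parser must then hand off to SentinelParser instead of emitting a message. A simplified sketch of that branch, under the pre-string_view signature used by these tests (buildMessage and stillWaiting are stand-ins, not confirmed names):

    ParseResponse parse(const char*& buffer, uint64_t& remaining) override {
      context_->remaining_request_size_ -= deserializer.feed(buffer, remaining);
      if (deserializer.ready()) {
        if (0 == context_->remaining_request_size_) {
          // whole request consumed and deserializer satisfied: emit the message
          return ParseResponse::parsedMessage(buildMessage()); // buildMessage is a stand-in
        } else {
          // deserializer claims completion while request bytes remain: treat the
          // request as broken and let SentinelParser drain the leftovers
          return ParseResponse::nextParser(std::make_shared<SentinelParser>(context_));
        }
      }
      return ParseResponse::stillWaiting(); // assumed "need more data" result
    }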
ASSERT_EQ(testee.contextForTest()->remaining_request_size_, + request_size - FAILED_DESERIALIZER_STEP); } TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { From bbf0e08c0443d481bc9c5eb5fb451940d763134a Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 26 Feb 2019 10:04:05 -0800 Subject: [PATCH 13/29] Fix formatting and clang-tidy Signed-off-by: Adam Kotwasinski --- .../filters/network/kafka/kafka_types.h | 1 + .../kafka_generator.py | 690 +++++++++--------- 2 files changed, 356 insertions(+), 335 deletions(-) diff --git a/source/extensions/filters/network/kafka/kafka_types.h b/source/extensions/filters/network/kafka/kafka_types.h index 4d7a6a09cd364..1aa32106cb390 100644 --- a/source/extensions/filters/network/kafka/kafka_types.h +++ b/source/extensions/filters/network/kafka/kafka_types.h @@ -2,6 +2,7 @@ #include #include +#include #include "absl/types/optional.h" diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index 3e4032c3b100b..bfa8a2d6a646c 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -9,394 +9,414 @@ # if generate-test: location of 'requests_test.cc' # INPUT_FILES: Kafka protocol json files to be processed + def main(): - import sys - import os + import sys + import os - command = sys.argv[1] - if 'generate-source' == command: - requests_h_file = os.path.abspath(sys.argv[2]) - kafka_request_resolver_cc_file = os.path.abspath(sys.argv[3]) - input_files = sys.argv[4:] - elif 'generate-test' == command: - requests_test_cc_file = os.path.abspath(sys.argv[2]) - input_files = sys.argv[3:] - else: - raise ValueError('invalid command: ' + command) + command = sys.argv[1] + if 'generate-source' == command: + requests_h_file = os.path.abspath(sys.argv[2]) + kafka_request_resolver_cc_file = os.path.abspath(sys.argv[3]) + input_files = sys.argv[4:] + elif 'generate-test' == command: + requests_test_cc_file = os.path.abspath(sys.argv[2]) + input_files = sys.argv[3:] + else: + raise ValueError('invalid command: ' + command) - import re - import json + import re + import json - requests = [] + requests = [] - for input_file in input_files: - with open(input_file, 'r') as fd: - raw_contents = fd.read() - without_comments = re.sub(r'//.*\n', '', raw_contents) - request_spec = json.loads(without_comments) - request = parse_request(request_spec) - if request is not None: # debugging - requests.append(request) + for input_file in input_files: + with open(input_file, 'r') as fd: + raw_contents = fd.read() + without_comments = re.sub(r'//.*\n', '', raw_contents) + request_spec = json.loads(without_comments) + request = parse_request(request_spec) + requests.append(request) - requests.sort(key = lambda x: x.get_extra('api_key')) + requests.sort(key=lambda x: x.get_extra('api_key')) + if 'generate-source' == command: + complex_type_template = RenderingHelper.get_template('complex_type_template.j2') + request_parsers_template = RenderingHelper.get_template('request_parser.j2') + requests_h_contents = '' - if 'generate-source' == command: - complex_type_template = RenderingHelper.get_template('complex_type_template.j2') - request_parsers_template = RenderingHelper.get_template('request_parser.j2') - requests_h_contents = '' + for request in requests: + # structures holding payload data + for dependency 
in request.declaration_chain: + requests_h_contents += complex_type_template.render(complex_type=dependency) + # request parser + requests_h_contents += request_parsers_template.render(complex_type=request) - for request in requests: - # structures holding payload data - for dependency in request.declaration_chain: - requests_h_contents += complex_type_template.render(complex_type = dependency) - # request parser - requests_h_contents += request_parsers_template.render(complex_type = request) + # full file with headers, namespace declaration etc. + requests_header_template = RenderingHelper.get_template('requests_h.j2') + contents = requests_header_template.render(contents=requests_h_contents) - # full file with headers, namespace declaration etc. - requests_header_template = RenderingHelper.get_template('requests_h.j2') - contents = requests_header_template.render(contents = requests_h_contents) + with open(requests_h_file, 'w') as fd: + fd.write(contents) - with open(requests_h_file, 'w') as fd: - fd.write(contents) + kafka_request_resolver_template = RenderingHelper.get_template('kafka_request_resolver_cc.j2') + contents = kafka_request_resolver_template.render(request_types=requests) - kafka_request_resolver_template = RenderingHelper.get_template('kafka_request_resolver_cc.j2') - contents = kafka_request_resolver_template.render(request_types = requests) + with open(kafka_request_resolver_cc_file, 'w') as fd: + fd.write(contents) - with open(kafka_request_resolver_cc_file, 'w') as fd: - fd.write(contents) + if 'generate-test' == command: + requests_test_template = RenderingHelper.get_template('requests_test_cc.j2') + contents = requests_test_template.render(request_types=requests) - if 'generate-test' == command: - requests_test_template = RenderingHelper.get_template('requests_test_cc.j2') - contents = requests_test_template.render(request_types = requests) + with open(requests_test_cc_file, 'w') as fd: + fd.write(contents) - with open(requests_test_cc_file, 'w') as fd: - fd.write(contents) def parse_request(spec): - # a request is just a complex type, that has name & versions kept in differently named fields - request_type_name = spec['name'] - request_versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) - return parse_complex_type(request_type_name, spec, request_versions).with_extra('api_key', spec['apiKey']) + # a request is just a complex type, that has name & versions kept in differently named fields + request_type_name = spec['name'] + request_versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) + return parse_complex_type(request_type_name, spec, request_versions).with_extra( + 'api_key', spec['apiKey']) + def parse_complex_type(type_name, field_spec, versions): - fields = [] - for child_field in field_spec['fields']: - child = parse_field(child_field, versions[-1]) - fields.append(child) - return Complex(type_name, fields, versions) + fields = [] + for child_field in field_spec['fields']: + child = parse_field(child_field, versions[-1]) + fields.append(child) + return Complex(type_name, fields, versions) + def parse_field(field_spec, highest_possible_version): - # obviously, field cannot be used in version higher than its type's usage - version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) - version_usage_as_nullable = Statics.parse_version_string(field_spec['nullableVersions'], highest_possible_version) if 'nullableVersions' in field_spec else range(-1) - parsed_type = 
parse_type(field_spec['type'], field_spec, highest_possible_version) - return FieldSpec(field_spec['name'], parsed_type, version_usage, version_usage_as_nullable) + # obviously, field cannot be used in version higher than its type's usage + version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) + version_usage_as_nullable = Statics.parse_version_string( + field_spec['nullableVersions'], + highest_possible_version) if 'nullableVersions' in field_spec else range(-1) + parsed_type = parse_type(field_spec['type'], field_spec, highest_possible_version) + return FieldSpec(field_spec['name'], parsed_type, version_usage, version_usage_as_nullable) + def parse_type(type_name, field_spec, highest_possible_version): - # array types are defined as `[]underlying_type` instead of having its own element with type inside :\ - if (type_name.startswith('[]')): - underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) - return Array(underlying_type) - else: - if (type_name in Primitive.PRIMITIVE_TYPE_NAMES): - return Primitive(type_name, field_spec.get('default')) - else: - versions = Statics.parse_version_string(field_spec['versions'], highest_possible_version) - return parse_complex_type(type_name, field_spec, versions) + # array types are defined as `[]underlying_type` instead of having its own element with type inside :\ + if (type_name.startswith('[]')): + underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) + return Array(underlying_type) + else: + if (type_name in Primitive.PRIMITIVE_TYPE_NAMES): + return Primitive(type_name, field_spec.get('default')) + else: + versions = Statics.parse_version_string(field_spec['versions'], highest_possible_version) + return parse_complex_type(type_name, field_spec, versions) + class Statics: - @staticmethod - def parse_version_string(raw_versions, highest_possible_version): - if raw_versions.endswith('+'): - return range(int(raw_versions[:-1]), highest_possible_version + 1) - else: - if '-' in raw_versions: - tokens = raw_versions.split('-', 1) - return range(int(tokens[0]), int(tokens[1]) + 1) - else: - single_version = int(raw_versions) - return range(single_version, single_version + 1) + @staticmethod + def parse_version_string(raw_versions, highest_possible_version): + if raw_versions.endswith('+'): + return range(int(raw_versions[:-1]), highest_possible_version + 1) + else: + if '-' in raw_versions: + tokens = raw_versions.split('-', 1) + return range(int(tokens[0]), int(tokens[1]) + 1) + else: + single_version = int(raw_versions) + return range(single_version, single_version + 1) + class FieldList: - def __init__(self, version, fields): - self.version = version - self.fields = fields - - def used_fields(self): - return filter(lambda x: x.used_in_version(self.version), self.fields) - - def constructor_signature(self): - parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) - return ', '.join(parameter_spec) - - def constructor_init_list(self): - init_list = [] - for field in self.fields: - if field.used_in_version(self.version): - if field.is_nullable(): - if field.is_nullable_in_version(self.version): - # field is optional, and the parameter is optional in this version - init_list_item = '%s_{%s}' % (field.name, field.name) - init_list.append(init_list_item) - else: - # field is optional, and the parameter is T in this version - init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) - 
init_list.append(init_list_item) - else: - # field is T, so parameter cannot be optional - init_list_item = '%s_{%s}' % (field.name, field.name) - init_list.append(init_list_item) - else: - # field is not used in this version, so we need to put in default value - init_list_item = '%s_{%s}' % (field.name, field.default_value()) - init_list.append(init_list_item) - pass - return ', '.join(init_list) - - def field_count(self): - return len(self.used_fields()) - - def example_value(self): - return ', '.join(map(lambda x: x.example_value_for_test(self.version), self.used_fields())) + def __init__(self, version, fields): + self.version = version + self.fields = fields + + def used_fields(self): + return filter(lambda x: x.used_in_version(self.version), self.fields) + + def constructor_signature(self): + parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) + return ', '.join(parameter_spec) + + def constructor_init_list(self): + init_list = [] + for field in self.fields: + if field.used_in_version(self.version): + if field.is_nullable(): + if field.is_nullable_in_version(self.version): + # field is optional, and the parameter is optional in this version + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # field is optional, and the parameter is T in this version + init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # field is T, so parameter cannot be optional + init_list_item = '%s_{%s}' % (field.name, field.name) + init_list.append(init_list_item) + else: + # field is not used in this version, so we need to put in default value + init_list_item = '%s_{%s}' % (field.name, field.default_value()) + init_list.append(init_list_item) + pass + return ', '.join(init_list) + + def field_count(self): + return len(self.used_fields()) + + def example_value(self): + return ', '.join(map(lambda x: x.example_value_for_test(self.version), self.used_fields())) + class FieldSpec: - def __init__(self, name, type, version_usage, version_usage_as_nullable): - import re - separated = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) - self.name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', separated).lower() - self.type = type - self.version_usage = version_usage - self.version_usage_as_nullable = version_usage_as_nullable - - def is_nullable(self): - return len(self.version_usage_as_nullable) > 0 - - def is_nullable_in_version(self, version): - return version in self.version_usage_as_nullable - - def used_in_version(self, version): - return version in self.version_usage - - def field_declaration(self): - if self.is_nullable(): - return 'absl::optional<%s> %s' % (self.type.name, self.name) - else: - return '%s %s' % (self.type.name, self.name) - - def parameter_declaration(self, version): - if self.is_nullable_in_version(version): - return 'absl::optional<%s> %s' % (self.type.name, self.name) - else: - return '%s %s' % (self.type.name, self.name) - - def default_value(self): - if self.is_nullable(): - return '{%s}' % self.type.default_value() - else: - return str(self.type.default_value()) - - def example_value_for_test(self, version): - if self.is_nullable(): - return 'absl::make_optional<%s>(%s)' % (self.type.name, self.type.example_value_for_test(version)) - else: - return str(self.type.example_value_for_test(version)) - - def deserializer_name_in_version(self, version): - if self.is_nullable_in_version(version): - return 'Nullable%s' % 
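To make the constructor_init_list() logic above concrete: for a hypothetical field that is a plain int32 in version 0 but nullable from version 1 on, the generated C++ would look roughly like this (type and field names invented for illustration):

    struct SampleRequest {
      const absl::optional<int32_t> timeout_;
      // constructor used in versions: [0] - plain parameter, wrapped on the way in
      SampleRequest(int32_t timeout): timeout_{absl::make_optional(timeout)} {};
      // constructor used in versions: [1] - parameter already optional, stored as-is
      SampleRequest(absl::optional<int32_t> timeout): timeout_{timeout} {};
    };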
self.type.deserializer_name_in_version(version) - else: - return self.type.deserializer_name_in_version(version) - - def is_printable(self): - return self.type.is_printable() + def __init__(self, name, type, version_usage, version_usage_as_nullable): + import re + separated = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + self.name = re.sub('([a-z0-9])([A-Z])', r'\1_\2', separated).lower() + self.type = type + self.version_usage = version_usage + self.version_usage_as_nullable = version_usage_as_nullable + + def is_nullable(self): + return len(self.version_usage_as_nullable) > 0 + + def is_nullable_in_version(self, version): + return version in self.version_usage_as_nullable + + def used_in_version(self, version): + return version in self.version_usage + + def field_declaration(self): + if self.is_nullable(): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def parameter_declaration(self, version): + if self.is_nullable_in_version(version): + return 'absl::optional<%s> %s' % (self.type.name, self.name) + else: + return '%s %s' % (self.type.name, self.name) + + def default_value(self): + if self.is_nullable(): + return '{%s}' % self.type.default_value() + else: + return str(self.type.default_value()) + + def example_value_for_test(self, version): + if self.is_nullable(): + return 'absl::make_optional<%s>(%s)' % (self.type.name, + self.type.example_value_for_test(version)) + else: + return str(self.type.example_value_for_test(version)) + + def deserializer_name_in_version(self, version): + if self.is_nullable_in_version(version): + return 'Nullable%s' % self.type.deserializer_name_in_version(version) + else: + return self.type.deserializer_name_in_version(version) + + def is_printable(self): + return self.type.is_printable() + class TypeSpecification: - def deserializer_name_in_version(self, version): - raise NotImplementedError() + def deserializer_name_in_version(self, version): + raise NotImplementedError() + + def default_value(self): + raise NotImplementedError() - def default_value(self): - raise NotImplementedError() + def example_value_for_test(self, version): + raise NotImplementedError() - def example_value_for_test(self, version): - raise NotImplementedError() + def is_printable(self): + raise NotImplementedError() - def is_printable(self): - raise NotImplementedError() class Array(TypeSpecification): - - def __init__(self, underlying): - self.underlying = underlying - self.declaration_chain = self.underlying.declaration_chain - @property - def name(self): - return 'std::vector<%s>' % self.underlying.name + def __init__(self, underlying): + self.underlying = underlying + self.declaration_chain = self.underlying.declaration_chain + + @property + def name(self): + return 'std::vector<%s>' % self.underlying.name + + def deserializer_name_in_version(self, version): + return 'ArrayDeserializer<%s, %s>' % (self.underlying.name, + self.underlying.deserializer_name_in_version(version)) - def deserializer_name_in_version(self, version): - return 'ArrayDeserializer<%s, %s>' % (self.underlying.name, self.underlying.deserializer_name_in_version(version) ) + def default_value(self): + return '{}' - def default_value(self): - return '{}' + def example_value_for_test(self, version): + return 'std::vector<%s>{ %s }' % (self.underlying.name, + self.underlying.example_value_for_test(version)) - def example_value_for_test(self, version): - return 'std::vector<%s>{ %s }' % (self.underlying.name, 
self.underlying.example_value_for_test(version)) + def is_printable(self): + return self.underlying.is_printable() - def is_printable(self): - return self.underlying.is_printable() class Primitive(TypeSpecification): - PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] - - KAFKA_TYPE_TO_ENVOY_TYPE = { - 'string': 'std::string', - 'bool': 'bool', - 'int8': 'int8_t', - 'int16': 'int16_t', - 'int32': 'int32_t', - 'int64': 'int64_t', - 'bytes': 'Bytes', - } - - KAFKA_TYPE_TO_DESERIALIZER = { - 'string': 'StringDeserializer', - 'bool': 'BooleanDeserializer', - 'int8': 'Int8Deserializer', - 'int16': 'Int16Deserializer', - 'int32': 'Int32Deserializer', - 'int64': 'Int64Deserializer', - 'bytes': 'BytesDeserializer', - } - - # https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages - KAFKA_TYPE_TO_DEFAULT_VALUE = { - 'string': '""', - 'bool': 'false', - 'int8': '0', - 'int16': '0', - 'int32': '0', - 'int64': '0', - 'bytes': '{}', - } - - # to make test code more readable - KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { - 'string': '"string"', - 'bool': 'false', - 'int8': '8', - 'int16': '16', - 'int32': '32', - 'int64': '64ll', - 'bytes': 'Bytes({0, 1, 2, 3})', - } - - def __init__(self, name, custom_default_value): - self.original_name = name - self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) - self.custom_default_value = custom_default_value - self.declaration_chain = [] - self.deserializer_name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) - - @staticmethod - def compute(name, map): - if name in map: - return map[name] - else: - raise ValueError(name) - - def deserializer_name_in_version(self, version): - return self.deserializer_name - - def default_value(self): - if self.custom_default_value is not None: - return self.custom_default_value - else: - return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) - - def example_value_for_test(self, version): - return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) - - def is_printable(self): - return self.name not in ['Bytes'] + PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] + + KAFKA_TYPE_TO_ENVOY_TYPE = { + 'string': 'std::string', + 'bool': 'bool', + 'int8': 'int8_t', + 'int16': 'int16_t', + 'int32': 'int32_t', + 'int64': 'int64_t', + 'bytes': 'Bytes', + } + + KAFKA_TYPE_TO_DESERIALIZER = { + 'string': 'StringDeserializer', + 'bool': 'BooleanDeserializer', + 'int8': 'Int8Deserializer', + 'int16': 'Int16Deserializer', + 'int32': 'Int32Deserializer', + 'int64': 'Int64Deserializer', + 'bytes': 'BytesDeserializer', + } + + # https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages + KAFKA_TYPE_TO_DEFAULT_VALUE = { + 'string': '""', + 'bool': 'false', + 'int8': '0', + 'int16': '0', + 'int32': '0', + 'int64': '0', + 'bytes': '{}', + } + + # to make test code more readable + KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { + 'string': '"string"', + 'bool': 'false', + 'int8': '8', + 'int16': '16', + 'int32': '32', + 'int64': '64ll', + 'bytes': 'Bytes({0, 1, 2, 3})', + } + + def __init__(self, name, custom_default_value): + self.original_name = name + self.name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_ENVOY_TYPE) + self.custom_default_value = custom_default_value + self.declaration_chain = [] + self.deserializer_name = Primitive.compute(name, Primitive.KAFKA_TYPE_TO_DESERIALIZER) + + 
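Combining the two mappings above: a spec field of type []int32 would surface in generated C++ roughly as follows (field name invented; the generator actually inlines the deserializer type into a composite's template arguments rather than emitting a named alias):

    struct SamplePayload {
      // spec type `[]int32` maps to std::vector<int32_t> (KAFKA_TYPE_TO_ENVOY_TYPE)
      const std::vector<int32_t> partition_ids_;
    };
    // ... and is read by the deserializer composed per KAFKA_TYPE_TO_DESERIALIZER:
    using PartitionIdsDeserializer = ArrayDeserializer<int32_t, Int32Deserializer>;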
@staticmethod + def compute(name, map): + if name in map: + return map[name] + else: + raise ValueError(name) + + def deserializer_name_in_version(self, version): + return self.deserializer_name + + def default_value(self): + if self.custom_default_value is not None: + return self.custom_default_value + else: + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_DEFAULT_VALUE) + + def example_value_for_test(self, version): + return Primitive.compute(self.original_name, Primitive.KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST) + + def is_printable(self): + return self.name not in ['Bytes'] + class Complex(TypeSpecification): - def __init__(self, name, fields, versions): - self.name = name - self.fields = fields - self.versions = versions - self.declaration_chain = self.__compute_declaration_chain() - self.attributes = {} - - def __compute_declaration_chain(self): - result = [] - for field in self.fields: - result.extend(field.type.declaration_chain) - result.append(self) - return result - - def with_extra(self, key, value): - self.attributes[key] = value - return self - - def get_extra(self, key): - return self.attributes[key] - - def compute_constructors(self): - # field lists for different versions may not differ (as Kafka can bump version without any changes) - # but constructors need to be unique - signature_to_constructor = {} - for field_list in self.compute_field_lists(): - signature = field_list.constructor_signature() - constructor = signature_to_constructor.get(signature) - if constructor is None: - entry = {} - entry['versions'] = [ field_list.version ] - entry['signature'] = signature - if (len(signature) > 0): - entry['full_declaration'] = '%s(%s): %s {};' % (self.name, signature, field_list.constructor_init_list()) - else: - entry['full_declaration'] = '%s() {};' % self.name - signature_to_constructor[signature] = entry - else: - constructor['versions'].append(field_list.version) - return sorted(signature_to_constructor.values(), key = lambda x: x['versions'][0]) - - def compute_field_lists(self): - field_lists = [] - for version in self.versions: - field_list = FieldList(version, self.fields) - field_lists.append(field_list) - return field_lists; - - def deserializer_name_in_version(self, version): - return '%sV%dDeserializer' % (self.name, version) - - def default_value(self): - raise NotImplementedError('unable to create default value of complex type') - - def example_value_for_test(self, version): - field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) - example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) - return '%s(%s)' % (self.name, ', '.join(example_values)) - - def is_printable(self): - return True + def __init__(self, name, fields, versions): + self.name = name + self.fields = fields + self.versions = versions + self.declaration_chain = self.__compute_declaration_chain() + self.attributes = {} + + def __compute_declaration_chain(self): + result = [] + for field in self.fields: + result.extend(field.type.declaration_chain) + result.append(self) + return result + + def with_extra(self, key, value): + self.attributes[key] = value + return self + + def get_extra(self, key): + return self.attributes[key] + + def compute_constructors(self): + # field lists for different versions may not differ (as Kafka can bump version without any changes) + # but constructors need to be unique + signature_to_constructor = {} + for field_list in self.compute_field_lists(): + signature = 
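The deduplication in compute_constructors() matters because Kafka bumps versions without structural changes; versions with identical field lists then share one constructor, yet each version still gets its own deserializer class. An illustration with invented names:

    // versions 0 and 1 of SampleRequest share a field list, hence a single
    // constructor annotated "// constructor used in versions: [0, 1]" in the
    // generated struct, but still one deserializer class per version:
    class SampleRequestV0Deserializer
        : public CompositeDeserializerWith2Delegates<SampleRequest, StringDeserializer,
                                                     Int32Deserializer> {};
    class SampleRequestV1Deserializer
        : public CompositeDeserializerWith2Delegates<SampleRequest, StringDeserializer,
                                                     Int32Deserializer> {};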
field_list.constructor_signature() + constructor = signature_to_constructor.get(signature) + if constructor is None: + entry = {} + entry['versions'] = [field_list.version] + entry['signature'] = signature + if (len(signature) > 0): + entry['full_declaration'] = '%s(%s): %s {};' % (self.name, signature, + field_list.constructor_init_list()) + else: + entry['full_declaration'] = '%s() {};' % self.name + signature_to_constructor[signature] = entry + else: + constructor['versions'].append(field_list.version) + return sorted(signature_to_constructor.values(), key=lambda x: x['versions'][0]) + + def compute_field_lists(self): + field_lists = [] + for version in self.versions: + field_list = FieldList(version, self.fields) + field_lists.append(field_list) + return field_lists + + def deserializer_name_in_version(self, version): + return '%sV%dDeserializer' % (self.name, version) + + def default_value(self): + raise NotImplementedError('unable to create default value of complex type') + + def example_value_for_test(self, version): + field_list = next(fl for fl in self.compute_field_lists() if fl.version == version) + example_values = map(lambda x: x.example_value_for_test(version), field_list.used_fields()) + return '%s(%s)' % (self.name, ', '.join(example_values)) + + def is_printable(self): + return True + class RenderingHelper: - @staticmethod - def get_template(template): - import jinja2 - import os - env = jinja2.Environment(loader = jinja2.FileSystemLoader(searchpath = os.path.dirname(os.path.abspath(__file__)))) - return env.get_template(template) + @staticmethod + def get_template(template): + import jinja2 + import os + env = jinja2.Environment( + loader=jinja2.FileSystemLoader(searchpath=os.path.dirname(os.path.abspath(__file__)))) + return env.get_template(template) + if __name__ == "__main__": - main() + main() From 76cc44ffe12a9223d4f45c60a1d0721121865aea Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 26 Feb 2019 12:19:44 -0800 Subject: [PATCH 14/29] Fix spelling Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/kafka_request.h | 4 ++-- source/extensions/filters/network/kafka/request_codec.cc | 2 +- source/extensions/filters/network/kafka/serialization.h | 6 +++--- .../filters/network/kafka/serialization_composite.h | 3 +-- .../filters/network/kafka/serialization_composite_test.cc | 3 +-- test/extensions/filters/network/kafka/serialization_test.cc | 2 +- tools/spelling_dictionary.txt | 4 ++++ 7 files changed, 13 insertions(+), 11 deletions(-) diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index fbd844e695e0b..5deffc4d91abf 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -86,8 +86,8 @@ class UnknownRequest : public AbstractRequest { UnknownRequest(const RequestHeader& request_header) : AbstractRequest{request_header} {}; // this isn't the prettiest, as we have thrown away the data - // XXX(adam.kotwasinski) discuss capturing the data as-is, and simply putting it back - // this would add ability to forward unknown types of requests in cluster-proxy + // XXX discuss capturing the data as-is, and simply putting it back + // this would add ability to forward unknown types of requests in cluster-proxy size_t encode(Buffer::Instance&) const override { throw EnvoyException("cannot serialize unknown request"); } diff --git a/source/extensions/filters/network/kafka/request_codec.cc 
b/source/extensions/filters/network/kafka/request_codec.cc index cc8c799749bcf..e23141c173ebe 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -55,7 +55,7 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl void MessageEncoderImpl::encode(const Message& message) { Buffer::OwnedImpl data_buffer; - // TODO (adam.kotwasinski) precompute the size instead of using temporary + // TODO precompute the size instead of using temporary // also, when we have 'computeSize' method, then we can push encoding request's size into // Request::encode int32_t data_len = message.encode(data_buffer); // encode data computing data length diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index 438ad84122e72..8f3fcd7e865ca 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -613,10 +613,10 @@ class NullableArrayDeserializer : public Deserializer& arg, Buffer::Instance& ds size_t written{0}; for (const T& el : *arg) { // for each of array elements, resolve the correct method again - // elements could be primitives or complex types, so calling `el.encode()` won't work + // elements could be primitives or complex types, so calling encode() on object won't work written += encode(el, dst); } return header_length + written; diff --git a/source/extensions/filters/network/kafka/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h index 9dd55b27ea4d6..15e4c5cd7caa0 100644 --- a/source/extensions/filters/network/kafka/serialization_composite.h +++ b/source/extensions/filters/network/kafka/serialization_composite.h @@ -1,5 +1,4 @@ -// FIXME(adam.kotwasinski) this file can be generated, as it's repeating the same code for 0..9 -// delegates +// XXX this file can be generated, as it's repeating the same code for 0..9 delegates #pragma once #include diff --git a/test/extensions/filters/network/kafka/serialization_composite_test.cc b/test/extensions/filters/network/kafka/serialization_composite_test.cc index d406ae214b63d..c3ff041a5e62a 100644 --- a/test/extensions/filters/network/kafka/serialization_composite_test.cc +++ b/test/extensions/filters/network/kafka/serialization_composite_test.cc @@ -1,5 +1,4 @@ -// FIXME(adam.kotwasinski) this file can be generated, as it's repeating the same code for 0..9 -// delegates +// XXX this file can be generated, as it's repeating the same code for 0..9 delegates #include "common/common/stack_array.h" #include "extensions/filters/network/kafka/serialization.h" diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 05c67fe490b29..2d43e1d3cbc06 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -146,7 +146,7 @@ template void serializeThenDeserializeAndCheckEqualit serializeThenDeserializeAndCheckEqualityWithChunks(expected); } -// macroed out test for numeric buffers +// extracted test for numeric buffers #define TEST_DeserializerShouldDeserialize(BufferClass, DataClass, Value) \ TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \ /* given */ \ diff --git a/tools/spelling_dictionary.txt b/tools/spelling_dictionary.txt index 4090d5064de2e..be7d2d8f68160 100644 --- a/tools/spelling_dictionary.txt +++ 
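For context on the serialization.h hunks above, a hedged reconstruction (not verbatim) of the nullable-array encoding that the adjusted comment describes: a null array encodes as INT32 -1, a present one as an INT32 length followed by its elements:

    template <typename T>
    size_t encode(const absl::optional<std::vector<T>>& arg, Buffer::Instance& dst) {
      if (arg) {
        const int32_t len = static_cast<int32_t>(arg->size());
        const size_t header_length = encode(len, dst); // length prefix
        size_t written{0};
        for (const T& el : *arg) {
          // resolve the correct overload per element: elements can be primitives
          // or complex types, so calling encode() on the object would not work
          written += encode(el, dst);
        }
        return header_length + written;
      } else {
        return encode(static_cast<int32_t>(-1), dst); // null marker
      }
    }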
b/tools/spelling_dictionary.txt @@ -366,7 +366,9 @@ dereferencing deregistered deserialization deserialize +deserialized deserializer +deserializers dest destructor destructors @@ -556,6 +558,7 @@ parameterizing params paren parentid +parsers pcall pcap pclose @@ -676,6 +679,7 @@ templating templatize templatized templatizing +testee th thru tm From 6f86d764275bcf1f3f7303e260603d4ab86084fe Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Mon, 11 Mar 2019 13:44:01 -0700 Subject: [PATCH 15/29] Fixes after review: string_view used instead of raw pointers; documentation Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/BUILD | 29 +- .../extensions/filters/network/kafka/codec.h | 3 +- .../filters/network/kafka/kafka_request.h | 2 +- .../network/kafka/kafka_request_parser.cc | 23 +- .../network/kafka/kafka_request_parser.h | 25 +- .../extensions/filters/network/kafka/parser.h | 7 +- .../complex_type_template.j2 | 21 + .../kafka_generator.py | 125 ++++- .../kafka_request_resolver_cc.j2 | 16 +- .../protocol_code_generator/request_parser.j2 | 8 + .../protocol_code_generator/requests_h.j2 | 25 +- .../requests_test_cc.j2 | 20 +- .../filters/network/kafka/request_codec.cc | 15 +- .../filters/network/kafka/serialization.h | 85 ++- .../serialization_composite_generator.py | 76 +++ .../serialization_composite_h.j2 | 98 ++++ .../serialization_composite_test_cc.j2 | 87 +++ .../network/kafka/serialization_composite.h | 468 ---------------- test/extensions/filters/network/kafka/BUILD | 33 +- .../kafka/kafka_request_parser_test.cc | 115 ++-- .../network/kafka/request_codec_test.cc | 2 +- .../kafka/serialization_composite_test.cc | 500 ------------------ .../network/kafka/serialization_test.cc | 139 +---- .../network/kafka/serialization_utilities.h | 120 +++++ 24 files changed, 768 insertions(+), 1274 deletions(-) create mode 100755 source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py create mode 100644 source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 create mode 100644 source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 delete mode 100644 source/extensions/filters/network/kafka/serialization_composite.h delete mode 100644 test/extensions/filters/network/kafka/serialization_composite_test.cc create mode 100644 test/extensions/filters/network/kafka/serialization_utilities.h diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index c86d91c01be67..c3a63cd0ccf03 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -54,7 +54,11 @@ genrule( "requests.h", "kafka_request_resolver.cc", ], - cmd = "./$(location :kafka_code_generator) generate-source $(location requests.h) $(location kafka_request_resolver.cc) $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file)", + cmd = """ + ./$(location :kafka_code_generator) generate-source \ + $(location requests.h) $(location kafka_request_resolver.cc) \ + $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file) + """, tools = [ ":kafka_code_generator", ], @@ -101,6 +105,29 @@ envoy_cc_library( ], ) +genrule( + name = "serialization_composite_generated_source", + srcs = [], + outs = [ + "serialization_composite.h", + ], + cmd = """ + ./$(location :serialization_composite_generator) generate-source \ + $(location 
serialization_composite.h) + """, + tools = [ + ":serialization_composite_generator", + ], +) + +py_binary( + name = "serialization_composite_generator", + srcs = ["serialization_code_generator/serialization_composite_generator.py"], + data = glob(["serialization_code_generator/*.j2"]), + main = "serialization_code_generator/serialization_composite_generator.py", + deps = ["@com_github_pallets_jinja//:jinja2"], +) + envoy_cc_library( name = "kafka_protocol_lib", hdrs = [ diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h index cfbd1c5337480..8aabd4d620d50 100644 --- a/source/extensions/filters/network/kafka/codec.h +++ b/source/extensions/filters/network/kafka/codec.h @@ -25,8 +25,7 @@ class MessageDecoder { }; /** - * Kafka message decoder - * @tparam MessageType message type (Kafka request or Kafka response) + * Kafka message encoder */ class MessageEncoder { public: diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 5deffc4d91abf..eaa6991b35af8 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -86,7 +86,7 @@ class UnknownRequest : public AbstractRequest { UnknownRequest(const RequestHeader& request_header) : AbstractRequest{request_header} {}; // this isn't the prettiest, as we have thrown away the data - // XXX discuss capturing the data as-is, and simply putting it back + // TODO(adamkotwasinski) discuss capturing the data as-is, and simply putting it back // this would add ability to forward unknown types of requests in cluster-proxy size_t encode(Buffer::Instance&) const override { throw EnvoyException("cannot serialize unknown request"); diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc index 4923d3d2948e4..d797b8dc07ca3 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.cc +++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc @@ -5,8 +5,12 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining) { - request_length_.feed(buffer, remaining); +const RequestParserResolver& RequestParserResolver::getDefaultInstance() { + CONSTRUCT_ON_FIRST_USE(RequestParserResolver); +} + +ParseResponse RequestStartParser::parse(absl::string_view& data) { + request_length_.feed(data); if (request_length_.ready()) { context_->remaining_request_size_ = request_length_.get(); return ParseResponse::nextParser( @@ -16,14 +20,14 @@ ParseResponse RequestStartParser::parse(const char*& buffer, uint64_t& remaining } } -ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remaining) { - const uint64_t orig_remaining = remaining; +ParseResponse RequestHeaderParser::parse(absl::string_view& data) { + const absl::string_view orig_data = data; try { - context_->remaining_request_size_ -= deserializer_->feed(buffer, remaining); + context_->remaining_request_size_ -= deserializer_->feed(data); } catch (const EnvoyException& e) { // unable to compute request header, but we still need to consume rest of request (some of the // data might have been consumed) - const int32_t consumed = static_cast(orig_remaining - remaining); + const int32_t consumed = static_cast(orig_data.size() - data.size()); context_->remaining_request_size_ -= consumed; 
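The central mechanical change of this patch is visible here: parsers now advance an absl::string_view instead of a (pointer, remaining) pair. The consumption idiom in isolation, as a standalone sketch:

    #include <algorithm>

    #include "absl/strings/string_view.h"

    // consuming n bytes shrinks the view from the front; "bytes consumed" then
    // falls out of a size comparison (orig_data.size() - data.size()), which is
    // exactly what the catch-block above computes
    size_t consume(absl::string_view& data, size_t n) {
      const size_t min = std::min(n, data.size());
      data = {data.data() + min, data.size() - min};
      return min;
    }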
context_->request_header_ = {-1, -1, -1, absl::nullopt}; return ParseResponse::nextParser(std::make_shared(context_)); @@ -40,10 +44,9 @@ ParseResponse RequestHeaderParser::parse(const char*& buffer, uint64_t& remainin } } -ParseResponse SentinelParser::parse(const char*& buffer, uint64_t& remaining) { - const size_t min = std::min(context_->remaining_request_size_, remaining); - buffer += min; - remaining -= min; +ParseResponse SentinelParser::parse(absl::string_view& data) { + const size_t min = std::min(context_->remaining_request_size_, data.size()); + data = {data.data() + min, data.size() - min}; context_->remaining_request_size_ -= min; if (0 == context_->remaining_request_size_) { return ParseResponse::parsedMessage( diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h index 76b60d5a7b6b9..cc29f8a655d4e 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.h +++ b/source/extensions/filters/network/kafka/kafka_request_parser.h @@ -44,9 +44,9 @@ class RequestParserResolver { RequestContextSharedPtr context) const; /** - * Request parser singleton + * Return default resolver, that uses request's api key and version to provide a matching parser */ - static const RequestParserResolver INSTANCE; + static const RequestParserResolver& getDefaultInstance(); }; /** @@ -59,10 +59,10 @@ class RequestStartParser : public Parser { : parser_resolver_{parser_resolver}, context_{std::make_shared()} {}; /** - * Consumes INT32 bytes as request length and updates the context with that value + * Consumes 4 bytes (INT32) as request length and updates the context with that value * @return RequestHeaderParser instance to process request header */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + ParseResponse parse(absl::string_view& data) override; const RequestContextSharedPtr contextForTest() const { return context_; } @@ -106,7 +106,7 @@ class RequestHeaderParser : public Parser { * Uses data provided to compute request header * @return Parser instance responsible for processing rest of the message */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + ParseResponse parse(absl::string_view& data) override; const RequestContextSharedPtr contextForTest() const { return context_; } @@ -128,7 +128,7 @@ class SentinelParser : public Parser { /** * Returns UnknownRequest */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override; + ParseResponse parse(absl::string_view& data) override; const RequestContextSharedPtr contextForTest() const { return context_; } @@ -156,17 +156,8 @@ template class RequestParser : * Consume enough data to fill in deserializer and receive the parsed request * Fill in request's header with data stored in context */ - ParseResponse parse(const char*& buffer, uint64_t& remaining) override { - const uint64_t orig_remaining = remaining; - try { - context_->remaining_request_size_ -= deserializer.feed(buffer, remaining); - } catch (const EnvoyException&) { - // treat the whole request as invalid, throw away the rest of the data - const int32_t consumed = static_cast(orig_remaining - remaining); - context_->remaining_request_size_ -= - consumed; // some of the data might have been consumed by throwing deserializer - return ParseResponse::nextParser(std::make_shared(context_)); - } + ParseResponse parse(absl::string_view& data) override { + context_->remaining_request_size_ -= deserializer.feed(data); if 
(deserializer.ready()) { if (0 == context_->remaining_request_size_) { diff --git a/source/extensions/filters/network/kafka/parser.h b/source/extensions/filters/network/kafka/parser.h index 8ceade142b6b5..9cef1eb19122f 100644 --- a/source/extensions/filters/network/kafka/parser.h +++ b/source/extensions/filters/network/kafka/parser.h @@ -7,6 +7,8 @@ #include "extensions/filters/network/kafka/kafka_types.h" #include "extensions/filters/network/kafka/message.h" +#include "absl/strings/string_view.h" + namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -27,11 +29,10 @@ class Parser : public Logger::Loggable { /** * Submit data to be processed by parser, will consume as much data as it is necessary to reach * the conclusion what should be the next parse step - * @param buffer data pointer, will be updated by parser - * @param remaining remaining data in buffer, will be updated by parser + * @param data bytes to be processed, will be updated by parser if any have been consumed * @return parse status - decision what should be done with current parser (keep/replace) */ - virtual ParseResponse parse(const char*& buffer, uint64_t& remaining) PURE; + virtual ParseResponse parse(absl::string_view& data) PURE; }; typedef std::shared_ptr ParserSharedPtr; diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 index aaf4ebce48633..289080b08749b 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 @@ -1,10 +1,27 @@ +{# + Template for structure representing a composite entity in Kafka protocol (e.g. 
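With the reworked Parser interface, an implementation is just a matter of shrinking the view and returning a status; a toy example (ParseResponse::stillWaiting() assumed here as the "need more data" result, as noted earlier):

    class DrainingParser : public Parser {
    public:
      ParseResponse parse(absl::string_view& data) override {
        data = {data.data() + data.size(), 0}; // consume every provided byte
        return ParseResponse::stillWaiting();
      }
    };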
FetchRequest, FetchRequestTopic, FetchRequestPartition) + Rendered templates for each structure in Kafka protocol will be put into 'requests.h' file + + Each structure is capable of holding all versions of given entity (which means its fields are actually a superset of the union of all versions' fields) + Each version has a dedicated deserializer (named $requestV$versionDeserializer), which calls the matching constructor + + To serialize, it is necessary to pass the encoding context (that contains the version that's being serialized) + Depending on the version, the fields will be written to the buffer +#} struct {{ complex_type.name }} { + + {# + Constructors invoked by deserializers + Each constructor has a signature that matches the fields in at least one version + (sometimes there are different Kafka versions that are actually composed of precisely the same fields) + #} {% for field in complex_type.fields %} const {{ field.field_declaration() }}_;{% endfor %} {% for constructor in complex_type.compute_constructors() %} // constructor used in versions: {{ constructor['versions'] }} {{ constructor['full_declaration'] }}{% endfor %} + {# For every field that's used in version, just serialize it #} {% if complex_type.fields|length > 0 %} size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { const int16_t api_version = encoder.apiVersion(); @@ -30,6 +47,10 @@ struct {{ complex_type.name }} { }; }; + +{# + Each structure version has a deserializer that matches the structure's field list +#} {% for field_list in complex_type.compute_field_lists() %} class {{ complex_type.name }}V{{ field_list.version }}Deserializer: public CompositeDeserializerWith{{ field_list.field_count() }}Delegates<{{ complex_type.name }}{% for field in field_list.used_fields() %}, {{ field.deserializer_name_in_version(field_list.version) }}{% endfor %}>{}; diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index bfa8a2d6a646c..ae5396d6531b5 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -1,16 +1,35 @@ #!/usr/bin/python -# usage: -# kafka_generator.py COMMAND OUTPUT FILES INPUT_FILES -# where: -# COMMAND : 'generate-source', to generate source files -# 'generate-test', to generate test files -# OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', -# if generate-test: location of 'requests_test.cc' -# INPUT_FILES: Kafka protocol json files to be processed - def main(): + """ + Kafka header generator script + ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ + Generates C++ headers from Kafka protocol specification + Can generate both main source code, as well as test code + + Usage: + kafka_generator.py COMMAND OUTPUT_FILES INPUT_FILES + where: + COMMAND : 'generate-source', to generate source files + 'generate-test', to generate test files + OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', + if generate-test: location of 'requests_test.cc' + INPUT_FILES: Kafka protocol json files to be processed + + Kafka spec files are provided at https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message and in Kafka clients jar file + When generating source code, it creates: + - requests.h - definition of all the structures/deserializers/parsers related to Kafka 
requests + - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from requests.h + When generating test code, it creates: + - requests_test.cc - serialization/deserialization tests for kafka structures + + Templates used are: + - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2 + - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2 + - to create 'requests_test.cc': requests_test_cc.j2 + """ + import sys import os @@ -30,6 +49,7 @@ def main(): requests = [] + # for each request spec file, remove comments, and parse the remains for input_file in input_files: with open(input_file, 'r') as fd: raw_contents = fd.read() @@ -38,18 +58,21 @@ def main(): request = parse_request(request_spec) requests.append(request) + # sort requests by api_key requests.sort(key=lambda x: x.get_extra('api_key')) + # main source code if 'generate-source' == command: complex_type_template = RenderingHelper.get_template('complex_type_template.j2') request_parsers_template = RenderingHelper.get_template('request_parser.j2') + requests_h_contents = '' for request in requests: - # structures holding payload data + # for each structure that is used by request, render its corresponding structures for dependency in request.declaration_chain: requests_h_contents += complex_type_template.render(complex_type=dependency) - # request parser + # each top-level structure (e.g. FetchRequest) is going to have corresponding parsers requests_h_contents += request_parsers_template.render(complex_type=request) # full file with headers, namespace declaration etc. @@ -65,6 +88,7 @@ def main(): with open(kafka_request_resolver_cc_file, 'w') as fd: fd.write(contents) + # test code if 'generate-test' == command: requests_test_template = RenderingHelper.get_template('requests_test_cc.j2') contents = requests_test_template.render(request_types=requests) @@ -74,7 +98,10 @@ def main(): def parse_request(spec): - # a request is just a complex type, that has name & versions kept in differently named fields + """ + Parse a given structure into a request + Request is just a complex type, that has name & versions kept in differently named fields + """ request_type_name = spec['name'] request_versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) return parse_complex_type(request_type_name, spec, request_versions).with_extra( @@ -82,6 +109,9 @@ def parse_request(spec): def parse_complex_type(type_name, field_spec, versions): + """ + Parse given complex type, returning a structure that holds its name, field specification and allowed versions + """ fields = [] for child_field in field_spec['fields']: child = parse_field(child_field, versions[-1]) @@ -90,7 +120,10 @@ def parse_complex_type(type_name, field_spec, versions): def parse_field(field_spec, highest_possible_version): - # obviously, field cannot be used in version higher than its type's usage + """ + Parse given field, returning a structure holding the name, type, and versions when this field is actually used (nullable or not) + Obviously, field cannot be used in version higher than its type's usage + """ version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) version_usage_as_nullable = Statics.parse_version_string( field_spec['nullableVersions'], @@ -100,8 +133,11 @@ def parse_field(field_spec, highest_possible_version): def parse_type(type_name, field_spec, highest_possible_version): - # array types are defined as `[]underlying_type` instead of having its 
own element with type inside :\ + """ + Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one (== struct) + """ if (type_name.startswith('[]')): + # in spec files, array types are defined as `[]underlying_type` instead of having its own element with type inside :\ underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) return Array(underlying_type) else: @@ -116,6 +152,9 @@ class Statics: @staticmethod def parse_version_string(raw_versions, highest_possible_version): + """ + Return the integer range that corresponds to a version string in the spec file + """ if raw_versions.endswith('+'): return range(int(raw_versions[:-1]), highest_possible_version + 1) else: @@ -128,19 +167,34 @@ def parse_version_string(raw_versions, highest_possible_version): class FieldList: + """ + List of fields used by a given entity (request or child structure) in a given request version + (as fields get added/removed across versions) + """ def __init__(self, version, fields): self.version = version self.fields = fields def used_fields(self): + """ + Return the list of fields that are actually used in this version of the structure + """ return filter(lambda x: x.used_in_version(self.version), self.fields) def constructor_signature(self): + """ + Return the constructor signature + Multiple versions of the same structure can have identical signatures (due to version bumps in Kafka) + """ parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) return ', '.join(parameter_spec) def constructor_init_list(self): + """ + Renders the member initialization list in the constructor + Takes care of potential optional conversions (as a field could be T in V1, but optional in V2) + """ init_list = [] for field in self.fields: if field.used_in_version(self.version): @@ -172,6 +226,10 @@ def example_value(self): class FieldSpec: + """ + Represents a field present in a structure (request, or child structure thereof) + Contains name, type, and versions when it is used (nullable or not) + """ def __init__(self, name, type, version_usage, version_usage_as_nullable): import re @@ -185,6 +243,11 @@ def is_nullable(self): return len(self.version_usage_as_nullable) > 0 def is_nullable_in_version(self, version): + """ + Whether this field is nullable in a given version + Fields can be non-nullable in earlier versions + See https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message#nullable-fields + """ return version in self.version_usage_as_nullable def used_in_version(self, version): @@ -228,9 +291,15 @@ def is_printable(self): class TypeSpecification: def deserializer_name_in_version(self, version): + """ + Renders the deserializer name of a given type, in a request with a given version + """ raise NotImplementedError() def default_value(self): + """ + Returns a default value for a given type + """ raise NotImplementedError() def example_value_for_test(self, version): @@ -241,6 +310,11 @@ def is_printable(self): class Array(TypeSpecification): + """ + Represents an array complex type + To use an instance of this type, it is necessary to declare the structures required by self.underlying + (e.g. 
+  (e.g. to use Array<Foo>, we need to have `struct Foo {...}`)
+  """
 
  def __init__(self, underlying):
    self.underlying = underlying
@@ -266,6 +340,9 @@ def is_printable(self):
 
 class Primitive(TypeSpecification):
+  """
+  Represents a Kafka primitive value
+  """
 
  PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes']
@@ -342,6 +419,10 @@ def is_printable(self):
 
 class Complex(TypeSpecification):
+  """
+  Represents a complex type (multiple types aggregated into one)
+  This type gets mapped to a C++ struct
+  """
 
  def __init__(self, name, fields, versions):
    self.name = name
@@ -351,6 +432,10 @@ def __init__(self, name, fields, versions):
    self.attributes = {}
 
  def __compute_declaration_chain(self):
+    """
+    Computes all dependencies, which means all non-primitive types used by this type
+    They need to be declared before this struct is declared
+    """
    result = []
    for field in self.fields:
      result.extend(field.type.declaration_chain)
@@ -365,8 +450,10 @@ def get_extra(self, key):
    return self.attributes[key]
 
  def compute_constructors(self):
-    # field lists for different versions may not differ (as Kafka can bump version without any changes)
-    # but constructors need to be unique
+    """
+    Field lists for different versions may not differ (as Kafka can bump a version without any changes)
+    But constructors need to be unique, so we need to remove duplicates if the signatures match
+    """
    signature_to_constructor = {}
    for field_list in self.compute_field_lists():
      signature = field_list.constructor_signature()
@@ -386,6 +473,9 @@ def compute_constructors(self):
    return sorted(signature_to_constructor.values(), key=lambda x: x['versions'][0])
 
  def compute_field_lists(self):
+    """
+    Returns field lists representing each of the structure's versions
+    """
    field_lists = []
    for version in self.versions:
      field_list = FieldList(version, self.fields)
@@ -408,6 +498,9 @@ def is_printable(self):
 
 class RenderingHelper:
+  """
+  Helper for jinja templates
+  """
 
  @staticmethod
  def get_template(template):
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2
index df08eec07a074..6ec0bfd7c6f6d 100644
--- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2
@@ -1,5 +1,7 @@
-// DO NOT EDIT - THIS FILE WAS GENERATED
-// clang-format off
+{#
+  Template for 'kafka_request_resolver.cc'
+  Defines the default Kafka request resolver, which uses the request parsers from the (also generated) 'requests.h'
+#}
 #include "extensions/filters/network/kafka/requests.h"
 #include "extensions/filters/network/kafka/kafka_request_parser.h"
 #include "extensions/filters/network/kafka/parser.h"
@@ -9,8 +11,13 @@ namespace Extensions {
 namespace NetworkFilters {
 namespace Kafka {
 
-const RequestParserResolver RequestParserResolver::INSTANCE;
-
+/**
+ * Creates a parser that corresponds to the provided key and version
+ * If a corresponding parser cannot be found (which means a newer version of the Kafka protocol), a sentinel parser is returned
+ * @param api_key Kafka request key
+ * @param api_version Kafka request's version
+ * @param context parse context
+ */
 ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version,
                                                     RequestContextSharedPtr context) const {
@@ -25,4 +32,3 @@ ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api
 } // namespace NetworkFilters
 } // namespace Extensions
 } // namespace Envoy
-// clang-format on
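For orientation: the createParser body is elided from the hunk above, but it presumably boils down to a plain dispatch over api_key & api_version. A hypothetical rendering for just two request types could look like the sketch below; the parser class names follow the 'request_parser.j2' template that comes next, while the SentinelParser constructor signature is an assumption:

    // Hypothetical output of kafka_request_resolver_cc.j2 - for illustration only.
    ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version,
                                                        RequestContextSharedPtr context) const {
      if (api_key == 0 && api_version == 0) {
        return std::make_shared<ProduceRequestV0Parser>(context);
      }
      if (api_key == 0 && api_version == 1) {
        return std::make_shared<ProduceRequestV1Parser>(context);
      }
      if (api_key == 1 && api_version == 0) {
        return std::make_shared<FetchRequestV0Parser>(context);
      }
      // unknown api_key & api_version pair (e.g. a newer Kafka protocol version):
      // fall back to the sentinel parser, which just consumes the rest of the request
      return std::make_shared<SentinelParser>(context);
    }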
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2
index 8f6d655e1fa31..0708d19402ab4 100644
--- a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2
@@ -1,3 +1,11 @@
+{#
+  Template for the top-level structure representing a request in the Kafka protocol (e.g. ProduceRequest, FetchRequest, ListOffsetsRequest etc.)
+  Rendered templates for each request in the Kafka protocol will be put into the 'requests.h' file
+
+  This template handles binding the top-level structure's deserializer (e.g. ProduceRequestV0Deserializer) with a RequestParser
+  These parsers are then used by the RequestParserResolver instance depending on the received Kafka api key & api version (see 'kafka_request_resolver_cc.j2')
+#}
+
 {% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> {
 public:
   {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {};
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2
index d6f454369c9de..8dc803f49bfc2 100644
--- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2
@@ -1,5 +1,25 @@
-// DO NOT EDIT - THIS FILE WAS GENERATED
-// clang-format off
+{#
+  Main template for the 'requests.h' file
+  Gets filled in (by 'contents') with Kafka request structures, deserializers, and parsers
+
+  For each request we have the following:
+  - 1 top-level structure corresponding to the request (e.g. `struct FetchRequest`)
+  - N deserializers for the top-level structure, one for each request version
+  - N parsers, one binding each deserializer into the parsing framework
+  - 0+ child structures (e.g. `struct FetchRequestTopic`, `FetchRequestPartition`) that compose into the top-level structure
+  - deserializers for each child structure (M = number of versions where the structure is actually used)
+
+  So for example, for FetchRequest we have:
+  - struct FetchRequest
+  - FetchRequestV0Deserializer, FetchRequestV1Deserializer, FetchRequestV2Deserializer, etc.
+  - FetchRequestV0Parser, FetchRequestV1Parser, FetchRequestV2Parser, etc.
+  - struct FetchRequestTopic
+  - FetchRequestTopicV0Deserializer, FetchRequestTopicV1Deserializer, FetchRequestTopicV2Deserializer, etc.
+    (because topic data is present in every FetchRequest version)
+  - struct FetchRequestPartition
+  - FetchRequestPartitionV0Deserializer, FetchRequestPartitionV1Deserializer, FetchRequestPartitionV2Deserializer, etc.
+    (because partition data is present in every FetchRequestTopic version)
+#}
 #pragma once
 #include "extensions/filters/network/kafka/kafka_request.h"
 #include "extensions/filters/network/kafka/kafka_request_parser.h"
@@ -12,4 +32,3 @@ namespace Kafka {
 {{ contents }}
 }}}}
-// clang-format on
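To make the template output concrete: a child structure and one of its per-version deserializers presumably render to something like the sketch below. The FetchRequestPartition field layout and the Int32Deserializer/Int64Deserializer names are assumptions for illustration (complex_type_template.j2 itself is not part of this excerpt); the composite deserializer base class is generated further down in this patch (serialization_composite.h):

    // Hypothetical excerpt of a generated 'requests.h' - for illustration only.
    struct FetchRequestPartition {
      const int32_t partition_;
      const int64_t fetch_offset_;
      const int32_t max_bytes_;

      FetchRequestPartition(int32_t partition, int64_t fetch_offset, int32_t max_bytes)
          : partition_{partition}, fetch_offset_{fetch_offset}, max_bytes_{max_bytes} {};
    };

    // The per-version deserializer is just a composite: it feeds its three delegates in
    // order and then constructs the struct from their results.
    class FetchRequestPartitionV0Deserializer
        : public CompositeDeserializerWith3Delegates<FetchRequestPartition, Int32Deserializer,
                                                     Int64Deserializer, Int32Deserializer> {};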
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2
index 45da13b677f3d..3a6d82e823245 100644
--- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2
@@ -1,5 +1,8 @@
-// DO NOT EDIT - THIS FILE WAS GENERATED
-// clang-format off
+{#
+  Template for request serialization/deserialization tests
+  For every request, we want to check if it can be serialized and deserialized properly
+#}
+
 #include "extensions/filters/network/kafka/requests.h"
 #include "extensions/filters/network/kafka/request_codec.h"
@@ -25,12 +28,17 @@ public:
   MOCK_METHOD1(onMessage, void(MessageSharedPtr));
 };
 
+/**
+ * Helper method
+ * Takes an instance of a request, serializes it, then deserializes it
+ * This method gets executed for every request * version pair
+ */
 template <typename T> std::shared_ptr<T> RequestDecoderTest::serializeAndDeserialize(T request) {
   MessageEncoderImpl serializer{buffer_};
   serializer.encode(request);
 
   std::shared_ptr<MockMessageListener> mock_listener = std::make_shared<MockMessageListener>();
-  RequestDecoder testee{RequestParserResolver::INSTANCE, {mock_listener}};
+  RequestDecoder testee{RequestParserResolver::getDefaultInstance(), {mock_listener}};
 
   MessageSharedPtr receivedMessage;
   EXPECT_CALL(*mock_listener, onMessage(testing::_)).WillOnce(testing::SaveArg<0>(&receivedMessage));
@@ -39,6 +47,11 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria
   return std::dynamic_pointer_cast<T>(receivedMessage);
 };
 
+
+{#
+  Concrete tests for each request_type and version (field_list)
+  Each request is naively constructed using some default values (put "string" as std::string, 32 as uint32_t, etc.)
+#}
 {% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %}
 TEST_F(RequestDecoderTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) {
   // given
@@ -58,4 +71,3 @@ TEST_F(RequestDecoderTest, shouldParse{{ request_type.name }}V{{ field_list.vers
 } // namespace NetworkFilters
 } // namespace Extensions
 } // namespace Envoy
-// clang-format on
diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc
index e23141c173ebe..2feaf4fb226b1 100644
--- a/source/extensions/filters/network/kafka/request_codec.cc
+++ b/source/extensions/filters/network/kafka/request_codec.cc
@@ -3,6 +3,8 @@
 #include "common/buffer/buffer_impl.h"
 #include "common/common/stack_array.h"
 
+#include "absl/strings/string_view.h"
+
 namespace Envoy {
 namespace Extensions {
 namespace NetworkFilters {
@@ -29,10 +31,11 @@ void RequestDecoder::onData(Buffer::Instance& data) {
  *   --- replace parser with new start parser, as we are going to parse another request
  */
 void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& slice) {
-  const char* buffer = reinterpret_cast<const char*>(slice.mem_);
-  uint64_t remaining = slice.len_;
-  while (remaining) {
-    ParseResponse result = parser->parse(buffer, remaining);
+  const char* bytes = reinterpret_cast<const char*>(slice.mem_);
+  absl::string_view data = {bytes, slice.len_};
+
+  while (!data.empty()) {
+    ParseResponse result = parser->parse(data);
     // this loop guarantees that parsers consuming 0 bytes also get processed
     while (result.hasData()) {
       if (!result.next_parser_) {
@@ -48,14 +51,14 @@ void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& sl
       } else {
         parser = result.next_parser_;
       }
-      result = parser->parse(buffer, remaining);
+      result = parser->parse(data);
     }
   }
 }
 
 void MessageEncoderImpl::encode(const Message& message) {
   Buffer::OwnedImpl data_buffer;
-  // TODO precompute the size instead of using temporary
+  // TODO(adamkotwasinski) precompute the size instead of using temporary
   // also, when we have 'computeSize' method, then we can push encoding request's size into
   // Request::encode
   int32_t data_len = message.encode(data_buffer); // encode data computing data length
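Before the serialization.h changes below, it may help to see the new feed() contract in isolation: a deserializer consumes only the bytes it needs, steps the passed-in view over them, and consumes nothing once ready. A minimal hand-written example follows (not part of the patch; it assumes only the Deserializer<T> interface shown in the next hunks):

    // Illustrative only: a deserializer that consumes exactly one byte, then reports ready.
    class SingleByteDeserializer : public Deserializer<int8_t> {
    public:
      size_t feed(absl::string_view& data) override {
        if (ready_ || data.empty()) {
          return 0; // feeding a ready deserializer consumes nothing
        }
        value_ = static_cast<int8_t>(data[0]);
        data = {data.data() + 1, data.size() - 1}; // step over the consumed byte
        ready_ = true;
        return 1;
      }

      bool ready() const override { return ready_; }

      int8_t get() const override { return value_; }

    private:
      int8_t value_{0};
      bool ready_{false};
    };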
diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h
index 8f3fcd7e865ca..c1d35fe357ab1 100644
--- a/source/extensions/filters/network/kafka/serialization.h
+++ b/source/extensions/filters/network/kafka/serialization.h
@@ -14,6 +14,8 @@
 #include "extensions/filters/network/kafka/kafka_types.h"
 
+#include "absl/strings/string_view.h"
+
 namespace Envoy {
 namespace Extensions {
 namespace NetworkFilters {
 namespace Kafka {
@@ -33,12 +35,14 @@ template class Deserializer {
   /**
    * Submit data to be processed, will consume as much data as it is necessary.
+   * If any bytes are consumed, then the provided string view is updated by stepping over consumed
+   * bytes.
    * Invoking this method when deserializer is ready has no effect (consumes 0 bytes)
-   * @param buffer data pointer, will be updated if data is consumed
-   * @param remaining remaining data in buffer, will be updated if data is consumed
-   * @return bytes consumed
+   *
+   * @param data bytes to be processed, will be updated if any have been consumed
+   * @return number of bytes consumed (equal to change in 'data')
    */
-  virtual size_t feed(const char*& buffer, uint64_t& remaining) PURE;
+  virtual size_t feed(absl::string_view& data) PURE;
 
   /**
    * Whether deserializer has consumed enough data to return result
@@ -60,17 +64,16 @@ template class IntDeserializer : public Deserializer {
 public:
   IntDeserializer() : written_{0}, ready_(false){};
 
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
-    const size_t available = std::min(sizeof(buf_) - written_, remaining);
-    memcpy(buf_ + written_, buffer, available);
+  size_t feed(absl::string_view& data) override {
+    const size_t available = std::min(sizeof(buf_) - written_, data.size());
+    memcpy(buf_ + written_, data.data(), available);
     written_ += available;
 
     if (written_ == sizeof(buf_)) {
       ready_ = true;
     }
 
-    buffer += available;
-    remaining -= available;
+    data = {data.data() + available, data.size() - available};
     return available;
   }
@@ -157,9 +160,7 @@ class BooleanDeserializer : public Deserializer {
 public:
   BooleanDeserializer(){};
 
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
-    return buffer_.feed(buffer, remaining);
-  }
+  size_t feed(absl::string_view& data) override { return buffer_.feed(data); }
 
   bool ready() const override { return buffer_.ready(); }
@@ -183,8 +184,8 @@ class StringDeserializer : public Deserializer {
   /**
    * Can throw EnvoyException if given string length is not valid
    */
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
-    const size_t length_consumed = length_buf_.feed(buffer, remaining);
+  size_t feed(absl::string_view& data) override {
+    const size_t length_consumed = length_buf_.feed(data);
     if (!length_buf_.ready()) {
       // break early: we still need to fill in length buffer
       return length_consumed;
@@ -200,13 +201,12 @@
       length_consumed_ = true;
     }
 
-    const size_t data_consumed = std::min(required_, remaining);
+    const size_t data_consumed = std::min(required_, data.size());
     const size_t written = data_buf_.size() - required_;
-    memcpy(data_buf_.data() + written, buffer, data_consumed);
+    memcpy(data_buf_.data() + written, data.data(), data_consumed);
     required_ -= data_consumed;
 
-    buffer += data_consumed;
-    remaining -= data_consumed;
+    data = {data.data() + data_consumed, data.size() - data_consumed};
 
     if (required_ == 0) {
       ready_ = true;
@@ -245,8 +245,8 @@ class NullableStringDeserializer : public Deserializer {
   /**
    * Can throw EnvoyException if given string length is not valid
    */
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
-    const size_t length_consumed = length_buf_.feed(buffer, remaining);
+  size_t feed(absl::string_view& data) override {
+    const size_t length_consumed = length_buf_.feed(data);
     if (!length_buf_.ready()) {
       // break early: we still need to fill in length buffer
       return length_consumed;
@@ -272,13 +272,12 @@
       return length_consumed;
     }
 
-    const size_t data_consumed = std::min(required_, remaining);
+    const size_t data_consumed = std::min(required_, data.size());
     const size_t written = data_buf_.size() - required_;
-    memcpy(data_buf_.data() + written, buffer, data_consumed);
+    memcpy(data_buf_.data() + written, data.data(), data_consumed);
     required_ -= data_consumed;
 
-    buffer += data_consumed;
-    remaining -= data_consumed;
+    data = {data.data() + data_consumed, data.size() - data_consumed};
 
     if (required_ == 0) {
       ready_ = true;
@@ -318,8 +317,8 @@ class BytesDeserializer : public Deserializer {
   /**
    * Can throw EnvoyException if given bytes length is not valid
    */
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
-    const size_t length_consumed = length_buf_.feed(buffer, remaining);
+  size_t feed(absl::string_view& data) override {
+    const size_t length_consumed = length_buf_.feed(data);
     if (!length_buf_.ready()) {
       // break early: we still need to fill in length buffer
       return length_consumed;
@@ -335,13 +334,12 @@
       length_consumed_ = true;
     }
 
-    const size_t data_consumed = std::min(required_, remaining);
+    const size_t data_consumed = std::min(required_, data.size());
     const size_t written = data_buf_.size() - required_;
-    memcpy(data_buf_.data() + written, buffer, data_consumed);
+    memcpy(data_buf_.data() + written, data.data(), data_consumed);
     required_ -= data_consumed;
 
-    buffer += data_consumed;
-    remaining -= data_consumed;
+    data = {data.data() + data_consumed, data.size() - data_consumed};
 
     if (required_ == 0) {
       ready_ = true;
@@ -378,8 +376,8 @@ class NullableBytesDeserializer : public Deserializer {
   /**
    * Can throw EnvoyException if given bytes length is not valid
    */
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
-    const size_t length_consumed = length_buf_.feed(buffer, remaining);
+  size_t feed(absl::string_view& data) override {
+    const size_t length_consumed = length_buf_.feed(data);
     if (!length_buf_.ready()) {
       // break early: we still need to fill in length buffer
       return length_consumed;
@@ -405,13 +403,12 @@
       return length_consumed;
     }
 
-    const size_t data_consumed = std::min(required_, remaining);
+    const size_t data_consumed = std::min(required_, data.size());
     const size_t written = data_buf_.size() - required_;
-    memcpy(data_buf_.data() + written, buffer, data_consumed);
+    memcpy(data_buf_.data() + written, data.data(), data_consumed);
     required_ -= data_consumed;
 
-    buffer += data_consumed;
-    remaining -= data_consumed;
+    data = {data.data() + data_consumed, data.size() - data_consumed};
 
     if (required_ == 0) {
       ready_ = true;
@@ -461,9 +458,9 @@ class ArrayDeserializer : public Deserializer> {
   /**
    * Can throw EnvoyException if array length is invalid or if DeserializerType can throw
    */
-  size_t feed(const char*& buffer, uint64_t& remaining) override {
+  size_t feed(absl::string_view& data) override {
 
-    const size_t length_consumed = length_buf_.feed(buffer, remaining);
+    const size_t length_consumed = length_buf_.feed(data);
     if (!length_buf_.ready()) {
       // break early: we still need to fill in length buffer
       return length_consumed;
@@ -485,7 +482,7 @@
 
     size_t child_consumed{0};
     for (DeserializerType& child : children_) {
-      child_consumed += child.feed(buffer, remaining);
+      child_consumed += child.feed(data);
     }
 
     bool children_ready_ = true;
@@ -538,9 +535,9 @@ class NullableArrayDeserializer : public Deserializer
diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2
new file mode 100644
--- /dev/null
+++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2
+#include
+#include
+#include
+
+#include "envoy/buffer/buffer.h"
+#include "envoy/common/exception.h"
+#include "envoy/common/pure.h"
+
+#include "common/common/byte_order.h"
+#include "common/common/fmt.h"
+
"extensions/filters/network/kafka/kafka_types.h" +#include "extensions/filters/network/kafka/serialization.h" + +#include "absl/strings/string_view.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * This header contains only composite deserializers + * The basic design is composite deserializer creating delegates DeserializerType1..N + * Result of type ResponseType is constructed by getting results of each of delegates + * These deserializers can throw, if any of the delegate deserializers can + */ + +/** + * Composite deserializer that uses 0 deserializer(s) (corner case) + * Always ready, as it consumes no bytes + * Creates a result value using a no-arg constructor + * @param ResponseType type of deserialized data + */ +template +class CompositeDeserializerWith0Delegates : public Deserializer { +public: + CompositeDeserializerWith0Delegates(){}; + size_t feed(absl::string_view&) override { return 0; } + bool ready() const override { return true; } + ResponseType get() const override { return {}; } +}; + +{% for field_count in counts %} +/** + * Composite deserializer that uses {{ field_count }} deserializer(s) + * Passes data to each of the underlying deserializers + * (deserializers that are already ready do not consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready + * (which means all deserializers before it are ready too) + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } + * + * @param ResponseType type of deserialized data +{% for field in range(1, field_count + 1) %} * @param DeserializerType{{ field }} deserializer {{ field }} (result used as argument {{ field }} of ResponseType's ctor) +{% endfor %} */ +template < + typename ResponseType{% for field in range(1, field_count + 1) %}, typename DeserializerType{{ field }}{% endfor %} +> +class CompositeDeserializerWith{{ field_count }}Delegates : public Deserializer { +public: + CompositeDeserializerWith{{ field_count }}Delegates(){}; + + size_t feed(absl::string_view& data) override { + size_t consumed = 0; + {% for field in range(1, field_count + 1) %} + consumed += delegate{{ field }}_.feed(data); + {% endfor %} + return consumed; + } + + bool ready() const override { return delegate{{ field_count }}_.ready(); } + + ResponseType get() const override { + return { + {% for field in range(1, field_count + 1) %}delegate{{ field }}_.get(), + {% endfor %}}; + } + +protected: + {% for field in range(1, field_count + 1) %} + DeserializerType{{ field }} delegate{{ field }}_; + {% endfor %} +}; +{% endfor %} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 new file mode 100644 index 0000000000000..4708b508da37c --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 @@ -0,0 +1,87 @@ +{# + Creates 'serialization_composite_test.cc' + + Template for composite serializer tests (the CompositeDeserializerWith_N_Delegates classes) + Covers the corner case of 0 delegates, and then uses templating to create tests for 1..N cases +#} + +#include "extensions/filters/network/kafka/serialization_composite.h" + +#include 
"test/extensions/filters/network/kafka/serialization_utilities.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +/** + * Tests in this class are supposed to check whether serialization operations on composite deserializers are behaving correctly + */ + +// tests for composite deserializer with 0 fields (corner case) + +struct CompositeResultWith0Fields { + size_t encode(Buffer::Instance&, EncodingContext&) const { return 0; } + bool operator==(const CompositeResultWith0Fields&) const { return true; } +}; + +typedef CompositeDeserializerWith0Delegates TestCompositeDeserializer0; + +// composite with 0 delegates is special case: it's always ready +TEST(CompositeDeserializerWith0Delegates, EmptyBufferShouldBeReady) { + // given + const TestCompositeDeserializer0 testee{}; + // when, then + ASSERT_EQ(testee.ready(), true); +} + +TEST(CompositeDeserializerWith0Delegates, ShouldDeserialize) { + const CompositeResultWith0Fields expected{}; + serializeThenDeserializeAndCheckEquality(expected); +} + +// tests for composite deserializer with N+ fields + +{% for field_count in counts %} +struct CompositeResultWith{{ field_count }}Fields { + {% for field in range(1, field_count + 1) %} + const std::string field{{ field }}_; + {% endfor %} + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + {% for field in range(1, field_count + 1) %} + written += encoder.encode(field{{ field }}_, dst); + {% endfor %} + return written; + } + + bool operator==(const CompositeResultWith{{ field_count }}Fields& rhs) const { + return true{% for field in range(1, field_count + 1) %} && field{{ field }}_ == rhs.field{{ field }}_{% endfor %}; + } +}; + +typedef CompositeDeserializerWith{{ field_count }}Delegates< + CompositeResultWith{{ field_count }}Fields + {% for field in range(1, field_count + 1) %}, StringDeserializer{% endfor %} +> TestCompositeDeserializer{{ field_count }}; + +TEST(CompositeDeserializerWith{{ field_count }}Delegates, EmptyBufferShouldNotBeReady) { + // given + const TestCompositeDeserializer{{ field_count }} testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); +} + +TEST(CompositeDeserializerWith{{ field_count }}Delegates, ShouldDeserialize) { + const CompositeResultWith{{ field_count }}Fields expected{ + {% for field in range(1, field_count + 1) %}"s{{ field }}", {% endfor %} + }; + serializeThenDeserializeAndCheckEquality(expected); +} +{% endfor %} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization_composite.h b/source/extensions/filters/network/kafka/serialization_composite.h deleted file mode 100644 index 15e4c5cd7caa0..0000000000000 --- a/source/extensions/filters/network/kafka/serialization_composite.h +++ /dev/null @@ -1,468 +0,0 @@ -// XXX this file can be generated, as it's repeating the same code for 0..9 delegates -#pragma once - -#include -#include -#include -#include - -#include "envoy/buffer/buffer.h" -#include "envoy/common/exception.h" -#include "envoy/common/pure.h" - -#include "common/common/byte_order.h" -#include "common/common/fmt.h" - -#include "extensions/filters/network/kafka/kafka_types.h" -#include "extensions/filters/network/kafka/serialization.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -/** - * This header contains only composite deserializers - * The basic design is composite deserializer creating 
delegates DeserializerType1..N - * Result of type ResponseType is constructed by getting results of each of delegates - * These deserializers can throw, if any of the delegate deserializers can - */ - -/** - * Composite deserializer that uses 0 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - */ -template -class CompositeDeserializerWith0Delegates : public Deserializer { -public: - CompositeDeserializerWith0Delegates(){}; - - size_t feed(const char*&, uint64_t&) override { return 0; } - - bool ready() const override { return true; } - - ResponseType get() const override { return {}; } - -protected: -}; - -/** - * Composite deserializer that uses 1 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith1Delegates : public Deserializer { -public: - CompositeDeserializerWith1Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate1_.ready(); } - - ResponseType get() const override { return {delegate1_.get()}; } - -protected: - DeserializerType1 delegate1_; -}; - -/** - * Composite deserializer that uses 2 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith2Delegates : public Deserializer { -public: - CompositeDeserializerWith2Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate2_.ready(); } - - ResponseType get() const override { return {delegate1_.get(), delegate2_.get()}; } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; -}; - -/** - * Composite deserializer that uses 3 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). 
- * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith3Delegates : public Deserializer { -public: - CompositeDeserializerWith3Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate3_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; -}; - -/** - * Composite deserializer that uses 4 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith4Delegates : public Deserializer { -public: - CompositeDeserializerWith4Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate4_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; -}; - -/** - * Composite deserializer that uses 5 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) - * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith5Delegates : public Deserializer { -public: - CompositeDeserializerWith5Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - consumed += delegate5_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate5_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), - delegate5_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; - DeserializerType5 delegate5_; -}; - -/** - * Composite deserializer that uses 6 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) - * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) - * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith6Delegates : public Deserializer { -public: - CompositeDeserializerWith6Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - consumed += delegate5_.feed(buffer, remaining); - consumed += delegate6_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate6_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), - delegate4_.get(), delegate5_.get(), delegate6_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; - DeserializerType5 delegate5_; - DeserializerType6 delegate6_; -}; - -/** - * Composite deserializer that uses 7 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) - * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) - * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) - * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith7Delegates : public Deserializer { -public: - CompositeDeserializerWith7Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - consumed += delegate5_.feed(buffer, remaining); - consumed += delegate6_.feed(buffer, remaining); - consumed += delegate7_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate7_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), - delegate5_.get(), delegate6_.get(), delegate7_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; - DeserializerType5 delegate5_; - DeserializerType6 delegate6_; - DeserializerType7 delegate7_; -}; - -/** - * Composite deserializer that uses 8 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) - * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) - * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) - * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) - * @param DeserializerType8 deserializer 8 (result used as argument 8 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith8Delegates : public Deserializer { -public: - CompositeDeserializerWith8Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - consumed += delegate5_.feed(buffer, remaining); - consumed += delegate6_.feed(buffer, remaining); - consumed += delegate7_.feed(buffer, remaining); - consumed += delegate8_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate8_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), delegate4_.get(), - delegate5_.get(), delegate6_.get(), delegate7_.get(), delegate8_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; - DeserializerType5 delegate5_; - DeserializerType6 delegate6_; - DeserializerType7 delegate7_; - DeserializerType8 delegate8_; -}; - -/** - * Composite deserializer that uses 9 deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) - * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... 
} - * - * @param ResponseType type of deserialized data - * @param DeserializerType1 deserializer 1 (result used as argument 1 of ResponseType's ctor) - * @param DeserializerType2 deserializer 2 (result used as argument 2 of ResponseType's ctor) - * @param DeserializerType3 deserializer 3 (result used as argument 3 of ResponseType's ctor) - * @param DeserializerType4 deserializer 4 (result used as argument 4 of ResponseType's ctor) - * @param DeserializerType5 deserializer 5 (result used as argument 5 of ResponseType's ctor) - * @param DeserializerType6 deserializer 6 (result used as argument 6 of ResponseType's ctor) - * @param DeserializerType7 deserializer 7 (result used as argument 7 of ResponseType's ctor) - * @param DeserializerType8 deserializer 8 (result used as argument 8 of ResponseType's ctor) - * @param DeserializerType9 deserializer 9 (result used as argument 9 of ResponseType's ctor) - */ -template -class CompositeDeserializerWith9Delegates : public Deserializer { -public: - CompositeDeserializerWith9Delegates(){}; - - size_t feed(const char*& buffer, uint64_t& remaining) override { - size_t consumed = 0; - consumed += delegate1_.feed(buffer, remaining); - consumed += delegate2_.feed(buffer, remaining); - consumed += delegate3_.feed(buffer, remaining); - consumed += delegate4_.feed(buffer, remaining); - consumed += delegate5_.feed(buffer, remaining); - consumed += delegate6_.feed(buffer, remaining); - consumed += delegate7_.feed(buffer, remaining); - consumed += delegate8_.feed(buffer, remaining); - consumed += delegate9_.feed(buffer, remaining); - return consumed; - } - - bool ready() const override { return delegate9_.ready(); } - - ResponseType get() const override { - return {delegate1_.get(), delegate2_.get(), delegate3_.get(), - delegate4_.get(), delegate5_.get(), delegate6_.get(), - delegate7_.get(), delegate8_.get(), delegate9_.get()}; - } - -protected: - DeserializerType1 delegate1_; - DeserializerType2 delegate2_; - DeserializerType3 delegate3_; - DeserializerType4 delegate4_; - DeserializerType5 delegate5_; - DeserializerType6 delegate6_; - DeserializerType7 delegate7_; - DeserializerType8 delegate8_; - DeserializerType9 delegate9_; -}; - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index f767e0a2b24c8..6b9af10d0b776 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -2,6 +2,7 @@ licenses(["notice"]) # Apache 2 load( "//bazel:envoy_build_system.bzl", + "envoy_cc_test_library", "envoy_package", ) load( @@ -11,11 +12,20 @@ load( envoy_package() +envoy_cc_test_library( + name = "serialization_utilities_lib", + hdrs = ["serialization_utilities.h"], + deps = [ + "//source/extensions/filters/network/kafka:serialization_lib", + ], +) + envoy_extension_cc_test( name = "serialization_test", srcs = ["serialization_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ + ":serialization_utilities_lib", "//source/extensions/filters/network/kafka:serialization_lib", "//test/mocks/server:server_mocks", ], @@ -26,16 +36,31 @@ envoy_extension_cc_test( srcs = ["serialization_composite_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ + ":serialization_utilities_lib", "//source/extensions/filters/network/kafka:serialization_lib", "//test/mocks/server:server_mocks", ], ) +genrule( + name = "serialization_composite_test_generator", + srcs 
= [], + outs = ["serialization_composite_test.cc"], + cmd = """ + ./$(location //source/extensions/filters/network/kafka:serialization_composite_generator) generate-test \ + $(location serialization_composite_test.cc) + """, + tools = [ + "//source/extensions/filters/network/kafka:serialization_composite_generator", + ], +) + envoy_extension_cc_test( name = "kafka_request_parser_test", srcs = ["kafka_request_parser_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ + ":serialization_utilities_lib", "//source/extensions/filters/network/kafka:kafka_request_lib", "//test/mocks/server:server_mocks", ], @@ -62,13 +87,17 @@ envoy_extension_cc_test( ) genrule( - name = "kafka_generated_test", + name = "requests_test_generator", srcs = [ "@kafka_produce_request_spec//file", "@kafka_fetch_request_spec//file", ], outs = ["requests_test.cc"], - cmd = "./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test $(location requests_test.cc) $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file)", + cmd = """ + ./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test \ + $(location requests_test.cc) \ + $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file) + """, tools = [ "//source/extensions/filters/network/kafka:kafka_code_generator", ], diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index f94e0ac3cbb45..8f1b9365a0522 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -1,11 +1,9 @@ -#include "common/common/stack_array.h" - #include "extensions/filters/network/kafka/kafka_request_parser.h" +#include "test/extensions/filters/network/kafka/serialization_utilities.h" #include "test/mocks/server/mocks.h" #include "gmock/gmock.h" -#include "gtest/gtest.h" using testing::_; using testing::Return; @@ -47,22 +45,23 @@ TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { int32_t request_len = 1234; encoder_.encode(request_len, buffer()); - const char* bytes = getBytes(); - uint64_t remaining = 1024; + const absl::string_view orig_data = {getBytes(), 1024}; + absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(bytes, remaining); + const ParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); ASSERT_EQ(result.message_, nullptr); ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); + assertStringViewIncrement(data, orig_data, sizeof(int32_t)); } class MockParser : public Parser { public: - ParseResponse parse(const char*&, uint64_t&) override { + ParseResponse parse(absl::string_view&) override { throw new EnvoyException("should not be invoked"); } }; @@ -82,29 +81,28 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNext const int16_t api_version{2}; const int32_t correlation_id{10}; const NullableString client_id{"aaa"}; - size_t written = 0; - written += encoder_.encode(api_key, buffer()); - written += encoder_.encode(api_version, buffer()); - written += encoder_.encode(correlation_id, buffer()); - written += encoder_.encode(client_id, buffer()); + size_t header_len = 0; + header_len += encoder_.encode(api_key, buffer()); + header_len += 
encoder_.encode(api_version, buffer()); + header_len += encoder_.encode(correlation_id, buffer()); + header_len += encoder_.encode(client_id, buffer()); - const char* bytes = getBytes(); - uint64_t remaining = 100000; - const uint64_t orig_remaining = remaining; + const absl::string_view orig_data = {getBytes(), 100000}; + absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(bytes, remaining); + const ParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_EQ(result.next_parser_, parser); ASSERT_EQ(result.message_, nullptr); - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len - written); - ASSERT_EQ(remaining, orig_remaining - written); - const RequestHeader expected_header{api_key, api_version, correlation_id, client_id}; ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); + ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len - header_len); + + assertStringViewIncrement(data, orig_data, header_len); } TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDuringFeeding) { @@ -113,10 +111,9 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur // throws during feeding class ThrowingRequestHeaderDeserializer : public RequestHeaderDeserializer { public: - size_t feed(const char*& buffer, uint64_t& remaining) override { + size_t feed(absl::string_view& data) override { // move some pointers to simulate data consumption - buffer += FAILED_DESERIALIZER_STEP; - remaining -= FAILED_DESERIALIZER_STEP; + data = {data.data() + FAILED_DESERIALIZER_STEP, data.size() - FAILED_DESERIALIZER_STEP}; throw EnvoyException("feed"); }; @@ -134,38 +131,30 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur RequestHeaderParser testee{parser_resolver, request_context, std::make_unique()}; - const char* bytes = getBytes(); - const char* orig_bytes = bytes; - uint64_t remaining = 100000; - const uint64_t orig_remaining = remaining; + const absl::string_view orig_data = {getBytes(), 100000}; + absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(bytes, remaining); + const ParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); ASSERT_EQ(result.message_, nullptr); - ASSERT_EQ(bytes, orig_bytes + FAILED_DESERIALIZER_STEP); - ASSERT_EQ(remaining, orig_remaining - FAILED_DESERIALIZER_STEP); - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_size - FAILED_DESERIALIZER_STEP); + + assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); } TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { // given - - const int32_t move = FAILED_DESERIALIZER_STEP; - // throws during feeding class ThrowingDeserializer : public Deserializer { public: - size_t feed(const char*& buffer, uint64_t& remaining) override { + size_t feed(absl::string_view&) override { // move some pointers to simulate data consumption - buffer += move; - remaining -= move; throw EnvoyException("feed"); }; @@ -174,37 +163,28 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFee int32_t get() const override { throw std::runtime_error("should not be invoked at all"); }; }; - const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request - RequestContextSharedPtr request_context{new 
RequestContext{request_size, {}}}; - + RequestContextSharedPtr request_context{new RequestContext{1024, {}}}; RequestParser testee{request_context}; - const char* bytes = getBytes(); - const char* orig_bytes = bytes; - uint64_t remaining = 100000; - const uint64_t orig_remaining = remaining; + absl::string_view data = {getBytes(), 100000}; // when - const ParseResponse result = testee.parse(bytes, remaining); + bool caught = false; + try { + testee.parse(data); + } catch (EnvoyException& e) { + caught = true; + } // then - ASSERT_EQ(result.hasData(), true); - ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); - ASSERT_EQ(result.message_, nullptr); - - ASSERT_EQ(bytes, orig_bytes + FAILED_DESERIALIZER_STEP); - ASSERT_EQ(remaining, orig_remaining - FAILED_DESERIALIZER_STEP); - - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, - request_size - FAILED_DESERIALIZER_STEP); + ASSERT_EQ(caught, true); } // deserializer that consumes FAILED_DESERIALIZER_STEP bytes and returns 0 class SomeBytesDeserializer : public Deserializer { public: - size_t feed(const char*& buffer, uint64_t& remaining) override { - buffer += FAILED_DESERIALIZER_STEP; - remaining -= FAILED_DESERIALIZER_STEP; + size_t feed(absl::string_view& data) override { + data = {data.data() + FAILED_DESERIALIZER_STEP, data.size() - FAILED_DESERIALIZER_STEP}; return FAILED_DESERIALIZER_STEP; }; @@ -220,24 +200,21 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyBut RequestParser testee{request_context}; - const char* bytes = getBytes(); - const char* orig_bytes = bytes; - uint64_t remaining = 100000; - const uint64_t orig_remaining = remaining; + const absl::string_view orig_data = {getBytes(), 100000}; + absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(bytes, remaining); + const ParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); ASSERT_EQ(result.message_, nullptr); - ASSERT_EQ(bytes, orig_bytes + FAILED_DESERIALIZER_STEP); - ASSERT_EQ(remaining, orig_remaining - FAILED_DESERIALIZER_STEP); - ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_size - FAILED_DESERIALIZER_STEP); + + assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); } TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { @@ -250,12 +227,11 @@ TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { const Bytes garbage(request_len * 2); encoder_.encode(garbage, buffer()); - const char* bytes = getBytes(); - uint64_t remaining = request_len * 2; - const uint64_t orig_remaining = remaining; + const absl::string_view orig_data = {getBytes(), request_len * 2}; + absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(bytes, remaining); + const ParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); @@ -263,7 +239,8 @@ TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); - ASSERT_EQ(remaining, orig_remaining - request_len); + + assertStringViewIncrement(data, orig_data, request_len); } } // namespace Kafka diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc index d273b6b1b20ab..4cab50a825d12 100644 --- 
a/test/extensions/filters/network/kafka/request_codec_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_test.cc @@ -35,7 +35,7 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria serializer.encode(request); std::shared_ptr mock_listener = std::make_shared(); - RequestDecoder testee{RequestParserResolver::INSTANCE, {mock_listener}}; + RequestDecoder testee{RequestParserResolver::getDefaultInstance(), {mock_listener}}; MessageSharedPtr receivedMessage; EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&receivedMessage)); diff --git a/test/extensions/filters/network/kafka/serialization_composite_test.cc b/test/extensions/filters/network/kafka/serialization_composite_test.cc deleted file mode 100644 index c3ff041a5e62a..0000000000000 --- a/test/extensions/filters/network/kafka/serialization_composite_test.cc +++ /dev/null @@ -1,500 +0,0 @@ -// XXX this file can be generated, as it's repeating the same code for 0..9 delegates -#include "common/common/stack_array.h" - -#include "extensions/filters/network/kafka/serialization.h" -#include "extensions/filters/network/kafka/serialization_composite.h" - -#include "test/mocks/server/mocks.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -/** - * Tests in this class are supposed to check whether serialization operations - * on composite deserializers are behaving correctly - */ - -// helper function -const char* getRawData(const Buffer::OwnedImpl& buffer) { - uint64_t num_slices = buffer.getRawSlices(nullptr, 0); - STACK_ARRAY(slices, Buffer::RawSlice, num_slices); - buffer.getRawSlices(slices.begin(), num_slices); - return reinterpret_cast((slices[0]).mem_); -} - -// exactly what is says on the tin: -// 1. serialize expected using Encoder -// 2. deserialize byte array using testee deserializer -// 3. verify result = expected -// 4. verify that data pointer moved correct amount -// 5. feed testee more data -// 6. verify that nothing more was consumed -template -void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { - // given - BT testee{}; - - Buffer::OwnedImpl buffer; - EncodingContext encoder{-1}; - const size_t written = encoder.encode(expected, buffer); - - uint64_t remaining = - 10 * - written; // tell parser that there is more data, it should never consume more than written - const uint64_t orig_remaining = remaining; - const char* data = getRawData(buffer); - const char* orig_data = data; - - // when - const size_t consumed = testee.feed(data, remaining); - - // then - ASSERT_EQ(consumed, written); - ASSERT_EQ(testee.ready(), true); - ASSERT_EQ(testee.get(), expected); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); - - // when - 2 - const size_t consumed2 = testee.feed(data, remaining); - - // then - 2 (nothing changes) - ASSERT_EQ(consumed2, 0); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); -} - -// does the same thing as the above test, -// but instead of providing whole data at one, it provides it in N one-byte chunks -// this verifies if deserializer keeps state properly (no overwrites etc.) 
-template -void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { - // given - BT testee{}; - - Buffer::OwnedImpl buffer; - EncodingContext encoder{-1}; - const size_t written = encoder.encode(expected, buffer); - - const char* data = getRawData(buffer); - const char* orig_data = data; - - // when - size_t consumed = 0; - for (size_t i = 0; i < written; ++i) { - uint64_t data_size = 1; - consumed += testee.feed(data, data_size); - ASSERT_EQ(data_size, 0); - } - - // then - ASSERT_EQ(consumed, written); - ASSERT_EQ(testee.ready(), true); - ASSERT_EQ(testee.get(), expected); - ASSERT_EQ(data, orig_data + consumed); - - // when - 2 - uint64_t remaining = 1024; - const size_t consumed2 = testee.feed(data, remaining); - - // then - 2 (nothing changes) - ASSERT_EQ(consumed2, 0); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, 1024); -} - -// wrapper to run both tests -template void serializeThenDeserializeAndCheckEquality(AT expected) { - serializeThenDeserializeAndCheckEqualityInOneGo(expected); - serializeThenDeserializeAndCheckEqualityWithChunks(expected); -} - -// tests for composite deserializers - -struct CompositeResultWith0Fields { - - size_t encode(Buffer::Instance&, EncodingContext&) const { return 0; } - - bool operator==(const CompositeResultWith0Fields&) const { return true; } -}; - -typedef CompositeDeserializerWith0Delegates TestCompositeDeserializer0; - -// composite with 0 delegates is special case: it's always ready -TEST(CompositeDeserializerWith0Delegates, EmptyBufferShouldBeReady) { - // given - const TestCompositeDeserializer0 testee{}; - // when, then - ASSERT_EQ(testee.ready(), true); -} - -TEST(CompositeDeserializerWith0Delegates, ShouldDeserialize) { - const CompositeResultWith0Fields expected{}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith1Fields { - const std::string field1_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - return written; - } - - bool operator==(const CompositeResultWith1Fields& rhs) const { return field1_ == rhs.field1_; } -}; - -typedef CompositeDeserializerWith1Delegates - TestCompositeDeserializer1; - -TEST(CompositeDeserializerWith1Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer1 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith1Delegates, ShouldDeserialize) { - const CompositeResultWith1Fields expected{"s1"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith2Fields { - const std::string field1_; - const std::string field2_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - return written; - } - - bool operator==(const CompositeResultWith2Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_; - } -}; - -typedef CompositeDeserializerWith2Delegates - TestCompositeDeserializer2; - -TEST(CompositeDeserializerWith2Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer2 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith2Delegates, ShouldDeserialize) { - const CompositeResultWith2Fields expected{"s1", "s2"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith3Fields { - const std::string field1_; - 
const std::string field2_; - const std::string field3_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - return written; - } - - bool operator==(const CompositeResultWith3Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_; - } -}; - -typedef CompositeDeserializerWith3Delegates - TestCompositeDeserializer3; - -TEST(CompositeDeserializerWith3Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer3 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith3Delegates, ShouldDeserialize) { - const CompositeResultWith3Fields expected{"s1", "s2", "s3"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith4Fields { - const std::string field1_; - const std::string field2_; - const std::string field3_; - const std::string field4_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - return written; - } - - bool operator==(const CompositeResultWith4Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && - field4_ == rhs.field4_; - } -}; - -typedef CompositeDeserializerWith4Delegates - TestCompositeDeserializer4; - -TEST(CompositeDeserializerWith4Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer4 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith4Delegates, ShouldDeserialize) { - const CompositeResultWith4Fields expected{"s1", "s2", "s3", "s4"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith5Fields { - const std::string field1_; - const std::string field2_; - const std::string field3_; - const std::string field4_; - const std::string field5_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - written += encoder.encode(field5_, dst); - return written; - } - - bool operator==(const CompositeResultWith5Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && - field4_ == rhs.field4_ && field5_ == rhs.field5_; - } -}; - -typedef CompositeDeserializerWith5Delegates - TestCompositeDeserializer5; - -TEST(CompositeDeserializerWith5Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer5 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith5Delegates, ShouldDeserialize) { - const CompositeResultWith5Fields expected{"s1", "s2", "s3", "s4", "s5"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith6Fields { - const std::string field1_; - const std::string field2_; - const std::string field3_; - const std::string field4_; - const std::string field5_; - const std::string field6_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - 
written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - written += encoder.encode(field5_, dst); - written += encoder.encode(field6_, dst); - return written; - } - - bool operator==(const CompositeResultWith6Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && - field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_; - } -}; - -typedef CompositeDeserializerWith6Delegates< - CompositeResultWith6Fields, StringDeserializer, StringDeserializer, StringDeserializer, - StringDeserializer, StringDeserializer, StringDeserializer> - TestCompositeDeserializer6; - -TEST(CompositeDeserializerWith6Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer6 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith6Delegates, ShouldDeserialize) { - const CompositeResultWith6Fields expected{"s1", "s2", "s3", "s4", "s5", "s6"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith7Fields { - const std::string field1_; - const std::string field2_; - const std::string field3_; - const std::string field4_; - const std::string field5_; - const std::string field6_; - const std::string field7_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - written += encoder.encode(field5_, dst); - written += encoder.encode(field6_, dst); - written += encoder.encode(field7_, dst); - return written; - } - - bool operator==(const CompositeResultWith7Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && - field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && - field7_ == rhs.field7_; - } -}; - -typedef CompositeDeserializerWith7Delegates< - CompositeResultWith7Fields, StringDeserializer, StringDeserializer, StringDeserializer, - StringDeserializer, StringDeserializer, StringDeserializer, StringDeserializer> - TestCompositeDeserializer7; - -TEST(CompositeDeserializerWith7Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer7 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith7Delegates, ShouldDeserialize) { - const CompositeResultWith7Fields expected{"s1", "s2", "s3", "s4", "s5", "s6", "s7"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith8Fields { - const std::string field1_; - const std::string field2_; - const std::string field3_; - const std::string field4_; - const std::string field5_; - const std::string field6_; - const std::string field7_; - const std::string field8_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - written += encoder.encode(field5_, dst); - written += encoder.encode(field6_, dst); - written += encoder.encode(field7_, dst); - written += encoder.encode(field8_, dst); - return written; - } - - bool operator==(const CompositeResultWith8Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == 
rhs.field3_ && - field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && - field7_ == rhs.field7_ && field8_ == rhs.field8_; - } -}; - -typedef CompositeDeserializerWith8Delegates< - CompositeResultWith8Fields, StringDeserializer, StringDeserializer, StringDeserializer, - StringDeserializer, StringDeserializer, StringDeserializer, StringDeserializer, - StringDeserializer> - TestCompositeDeserializer8; - -TEST(CompositeDeserializerWith8Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer8 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith8Delegates, ShouldDeserialize) { - const CompositeResultWith8Fields expected{"s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -struct CompositeResultWith9Fields { - const std::string field1_; - const std::string field2_; - const std::string field3_; - const std::string field4_; - const std::string field5_; - const std::string field6_; - const std::string field7_; - const std::string field8_; - const std::string field9_; - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, dst); - written += encoder.encode(field2_, dst); - written += encoder.encode(field3_, dst); - written += encoder.encode(field4_, dst); - written += encoder.encode(field5_, dst); - written += encoder.encode(field6_, dst); - written += encoder.encode(field7_, dst); - written += encoder.encode(field8_, dst); - written += encoder.encode(field9_, dst); - return written; - } - - bool operator==(const CompositeResultWith9Fields& rhs) const { - return field1_ == rhs.field1_ && field2_ == rhs.field2_ && field3_ == rhs.field3_ && - field4_ == rhs.field4_ && field5_ == rhs.field5_ && field6_ == rhs.field6_ && - field7_ == rhs.field7_ && field8_ == rhs.field8_ && field9_ == rhs.field9_; - } -}; - -typedef CompositeDeserializerWith9Delegates< - CompositeResultWith9Fields, StringDeserializer, StringDeserializer, StringDeserializer, - StringDeserializer, StringDeserializer, StringDeserializer, StringDeserializer, - StringDeserializer, StringDeserializer> - TestCompositeDeserializer9; - -TEST(CompositeDeserializerWith9Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer9 testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); -} - -TEST(CompositeDeserializerWith9Delegates, ShouldDeserialize) { - const CompositeResultWith9Fields expected{"s1", "s2", "s3", "s4", "s5", "s6", "s7", "s8", "s9"}; - serializeThenDeserializeAndCheckEquality(expected); -} - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 2d43e1d3cbc06..20cd1a4cb71fb 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -1,12 +1,6 @@ -#include "common/common/stack_array.h" - #include "extensions/filters/network/kafka/serialization.h" -#include "extensions/filters/network/kafka/serialization_composite.h" - -#include "test/mocks/server/mocks.h" -#include "gmock/gmock.h" -#include "gtest/gtest.h" +#include "test/extensions/filters/network/kafka/serialization_utilities.h" namespace Envoy { namespace Extensions { @@ -14,8 +8,8 @@ namespace NetworkFilters { namespace Kafka { /** - * Tests in this 
class are supposed to check whether serialization operations - * on Kafka-primitive types are behaving correctly + * Tests in this file are supposed to check whether serialization operations + * on Kafka-primitive types (ints, strings, arrays) are behaving correctly */ // freshly created deserializers should not be ready @@ -51,101 +45,6 @@ TEST(NullableArrayDeserializer, EmptyBufferShouldNotBeReady) { ASSERT_EQ(testee.ready(), false); } -EncodingContext encoder{-1}; // api_version does not matter for primitive types - -// helper function -const char* getRawData(const Buffer::OwnedImpl& buffer) { - uint64_t num_slices = buffer.getRawSlices(nullptr, 0); - STACK_ARRAY(slices, Buffer::RawSlice, num_slices); - buffer.getRawSlices(slices.begin(), num_slices); - return reinterpret_cast((slices[0]).mem_); -} - -// exactly what is says on the tin: -// 1. serialize expected using Encoder -// 2. deserialize byte array using testee deserializer -// 3. verify result = expected -// 4. verify that data pointer moved correct amount -// 5. feed testee more data -// 6. verify that nothing more was consumed -template -void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { - // given - BT testee{}; - - Buffer::OwnedImpl buffer; - const size_t written = encoder.encode(expected, buffer); - - uint64_t remaining = - 10 * - written; // tell parser that there is more data, it should never consume more than written - const uint64_t orig_remaining = remaining; - const char* data = getRawData(buffer); - const char* orig_data = data; - - // when - const size_t consumed = testee.feed(data, remaining); - - // then - ASSERT_EQ(consumed, written); - ASSERT_EQ(testee.ready(), true); - ASSERT_EQ(testee.get(), expected); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); - - // when - 2 - const size_t consumed2 = testee.feed(data, remaining); - - // then - 2 (nothing changes) - ASSERT_EQ(consumed2, 0); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, orig_remaining - consumed); -} - -// does the same thing as the above test, -// but instead of providing whole data at one, it provides it in N one-byte chunks -// this verifies if deserializer keeps state properly (no overwrites etc.) 
-template -void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { - // given - BT testee{}; - - Buffer::OwnedImpl buffer; - const size_t written = encoder.encode(expected, buffer); - - const char* data = getRawData(buffer); - const char* orig_data = data; - - // when - size_t consumed = 0; - for (size_t i = 0; i < written; ++i) { - uint64_t data_size = 1; - consumed += testee.feed(data, data_size); - ASSERT_EQ(data_size, 0); - } - - // then - ASSERT_EQ(consumed, written); - ASSERT_EQ(testee.ready(), true); - ASSERT_EQ(testee.get(), expected); - ASSERT_EQ(data, orig_data + consumed); - - // when - 2 - uint64_t remaining = 1024; - const size_t consumed2 = testee.feed(data, remaining); - - // then - 2 (nothing changes) - ASSERT_EQ(consumed2, 0); - ASSERT_EQ(data, orig_data + consumed); - ASSERT_EQ(remaining, 1024); -} - -// wrapper to run both tests -template void serializeThenDeserializeAndCheckEquality(AT expected) { - serializeThenDeserializeAndCheckEqualityInOneGo(expected); - serializeThenDeserializeAndCheckEqualityWithChunks(expected); -} - // extracted test for numeric buffers #define TEST_DeserializerShouldDeserialize(BufferClass, DataClass, Value) \ TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \ @@ -161,6 +60,8 @@ TEST_DeserializerShouldDeserialize(UInt32Deserializer, uint32_t, 42); TEST_DeserializerShouldDeserialize(Int64Deserializer, int64_t, 42); TEST_DeserializerShouldDeserialize(BooleanDeserializer, bool, true); +EncodingContext encoder{-1}; // api_version does not matter for primitive types + TEST(StringDeserializer, ShouldDeserialize) { const std::string value = "sometext"; serializeThenDeserializeAndCheckEquality(value); @@ -179,12 +80,11 @@ TEST(StringDeserializer, ShouldThrowOnInvalidLength) { int16_t len = -1; // STRING accepts only >= 0 encoder.encode(len, buffer); - uint64_t remaining = 1024; - const char* data = getRawData(buffer); + absl::string_view data = {getRawData(buffer), 1024}; // when // then - EXPECT_THROW(testee.feed(data, remaining), EnvoyException); + EXPECT_THROW(testee.feed(data), EnvoyException); } TEST(NullableStringDeserializer, ShouldDeserializeString) { @@ -213,12 +113,11 @@ TEST(NullableStringDeserializer, ShouldThrowOnInvalidLength) { int16_t len = -2; // -1 is OK for NULLABLE_STRING encoder.encode(len, buffer); - uint64_t remaining = 1024; - const char* data = getRawData(buffer); + absl::string_view data = {getRawData(buffer), 1024}; // when // then - EXPECT_THROW(testee.feed(data, remaining), EnvoyException); + EXPECT_THROW(testee.feed(data), EnvoyException); } TEST(BytesDeserializer, ShouldDeserialize) { @@ -239,12 +138,11 @@ TEST(BytesDeserializer, ShouldThrowOnInvalidLength) { const int32_t bytes_length = -1; // BYTES accepts only >= 0 encoder.encode(bytes_length, buffer); - uint64_t remaining = 1024; - const char* data = getRawData(buffer); + absl::string_view data = {getRawData(buffer), 1024}; // when // then - EXPECT_THROW(testee.feed(data, remaining), EnvoyException); + EXPECT_THROW(testee.feed(data), EnvoyException); } TEST(NullableBytesDeserializer, ShouldDeserialize) { @@ -270,12 +168,11 @@ TEST(NullableBytesDeserializer, ShouldThrowOnInvalidLength) { const int32_t bytes_length = -2; // -1 is OK for NULLABLE_BYTES encoder.encode(bytes_length, buffer); - uint64_t remaining = 1024; - const char* data = getRawData(buffer); + absl::string_view data = {getRawData(buffer), 1024}; // when // then - EXPECT_THROW(testee.feed(data, remaining), EnvoyException); + EXPECT_THROW(testee.feed(data), EnvoyException); } 
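// For illustration, a minimal sketch of how the string_view-based feed API exercised by
// these tests is driven end-to-end; it reuses the Int64Deserializer / EncodingContext /
// getRawData names from this patch, and assumes Kafka's fixed-width big-endian encoding
// (8 bytes for an INT64):
//
//   Buffer::OwnedImpl buffer;
//   EncodingContext encoder{-1}; // api_version does not matter for primitive types
//   const size_t written = encoder.encode(int64_t{42}, buffer); // writes 8 bytes
//
//   absl::string_view data = {getRawData(buffer), written};
//   Int64Deserializer testee;
//   const size_t consumed = testee.feed(data); // == written
//
// feed() consumes from the front of 'data' and advances the view in place: data.data()
// moves forward by 'consumed' and data.size() shrinks by the same amount; afterwards
// testee.ready() is true and testee.get() returns 42.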
 TEST(ArrayDeserializer, ShouldConsumeCorrectAmountOfData) {
@@ -292,12 +189,11 @@ TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) {
   const int32_t len = -1; // ARRAY accepts only >= 0
   encoder.encode(len, buffer);
-  uint64_t remaining = 1024;
-  const char* data = getRawData(buffer);
+  absl::string_view data = {getRawData(buffer), 1024};
   // when
   // then
-  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+  EXPECT_THROW(testee.feed(data), EnvoyException);
 }
 TEST(NullableArrayDeserializer, ShouldConsumeCorrectAmountOfData) {
@@ -314,12 +210,11 @@ TEST(NullableArrayDeserializer, ShouldThrowOnInvalidLength) {
   const int32_t len = -2; // -1 is OK for ARRAY
   encoder.encode(len, buffer);
-  uint64_t remaining = 1024;
-  const char* data = getRawData(buffer);
+  absl::string_view data = {getRawData(buffer), 1024};
   // when
   // then
-  EXPECT_THROW(testee.feed(data, remaining), EnvoyException);
+  EXPECT_THROW(testee.feed(data), EnvoyException);
 }
 } // namespace Kafka
diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h
new file mode 100644
index 0000000000000..f7674c1548530
--- /dev/null
+++ b/test/extensions/filters/network/kafka/serialization_utilities.h
@@ -0,0 +1,120 @@
+#pragma once
+
+#include "common/buffer/buffer_impl.h"
+#include "common/common/stack_array.h"
+
+#include "absl/strings/string_view.h"
+#include "gtest/gtest.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+/**
+ * Verifies that 'incremented' string view is actually the 'original' string view that has been
+ * advanced by 'difference' bytes.
+ */
+void assertStringViewIncrement(absl::string_view incremented, absl::string_view original,
+                               size_t difference) {
+
+  ASSERT_EQ(incremented.data(), original.data() + difference);
+  ASSERT_EQ(incremented.size(), original.size() - difference);
+}
+
+// helper function
+const char* getRawData(const Buffer::OwnedImpl& buffer) {
+  uint64_t num_slices = buffer.getRawSlices(nullptr, 0);
+  STACK_ARRAY(slices, Buffer::RawSlice, num_slices);
+  buffer.getRawSlices(slices.begin(), num_slices);
+  return reinterpret_cast<const char*>((slices[0]).mem_);
+}
+
+// exactly what it says on the tin:
+// 1. serialize 'expected' using Encoder
+// 2. deserialize the byte array using testee deserializer
+// 3. verify result == expected
+// 4. verify that the data pointer moved by the correct amount
+// 5. feed the testee more data
+// 6. verify that nothing more was consumed
+template <typename BT, typename AT>
+void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) {
+  // given
+  BT testee{};
+
+  Buffer::OwnedImpl buffer;
+  EncodingContext encoder{-1};
+  const size_t written = encoder.encode(expected, buffer);
+
+  // tell parser that there is more data, it should never consume more than written
+  const absl::string_view orig_data = {getRawData(buffer), 10 * written};
+  absl::string_view data = orig_data;
+
+  // when
+  const size_t consumed = testee.feed(data);
+
+  // then
+  ASSERT_EQ(consumed, written);
+  ASSERT_EQ(testee.ready(), true);
+  ASSERT_EQ(testee.get(), expected);
+  assertStringViewIncrement(data, orig_data, consumed);
+
+  // when - 2
+  const size_t consumed2 = testee.feed(data);
+
+  // then - 2 (nothing changes)
+  ASSERT_EQ(consumed2, 0);
+  assertStringViewIncrement(data, orig_data, consumed);
+}
+
+// does the same thing as the above test,
+// but instead of providing the whole data at once, it provides it in N one-byte chunks
+// this verifies that the deserializer keeps state properly (no overwrites etc.)
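// The chunked variant below exists because deserializers must keep partial state between
// feed() calls: e.g. an 8-byte INT64 delivered one byte at a time takes eight calls, with
// ready() becoming true only after the last one. A sketch under that assumption
// ('serialized_bytes' is a hypothetical pointer to the encoded value):
//
//   Int64Deserializer testee;
//   for (size_t i = 0; i < 8; ++i) {
//     absl::string_view chunk = {serialized_bytes + i, 1};
//     testee.feed(chunk); // consumes the single byte; 'chunk' becomes empty
//   }
//   // only now: testee.ready() == true
//
// This mirrors production traffic, where one Kafka request can arrive split across
// multiple network reads (and thus multiple onData() invocations).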
+template +void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { + // given + BT testee{}; + + Buffer::OwnedImpl buffer; + EncodingContext encoder{-1}; + const size_t written = encoder.encode(expected, buffer); + + const absl::string_view orig_data = {getRawData(buffer), written}; + + // when + absl::string_view data = orig_data; + size_t consumed = 0; + for (size_t i = 0; i < written; ++i) { + data = {data.data(), 1}; // consume data byte-by-byte + size_t step = testee.feed(data); + consumed += step; + ASSERT_EQ(step, 1); + ASSERT_EQ(data.size(), 0); + } + + // then + ASSERT_EQ(consumed, written); + ASSERT_EQ(testee.ready(), true); + ASSERT_EQ(testee.get(), expected); + assertStringViewIncrement(data, orig_data, consumed); + + // when - 2 + absl::string_view more_data = {data.data(), 1024}; + const size_t consumed2 = testee.feed(more_data); + + // then - 2 (nothing changes) + ASSERT_EQ(consumed2, 0); + ASSERT_EQ(more_data.data(), orig_data.data() + consumed); + ASSERT_EQ(more_data.size(), 1024); +} + +// wrapper to run both tests +template void serializeThenDeserializeAndCheckEquality(AT expected) { + serializeThenDeserializeAndCheckEqualityInOneGo(expected); + serializeThenDeserializeAndCheckEqualityWithChunks(expected); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy From f369bfa694051e50ad61c518fff6330e9d1b7359 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 19 Mar 2019 12:36:28 -0700 Subject: [PATCH 16/29] Fixes: spelling, clang-tidy, documentation, cleaning up internal API; request codec tests Signed-off-by: Adam Kotwasinski --- .../filters/network/kafka/kafka_request.h | 5 +- .../kafka_generator.py | 27 ++- ...quest_codec_request_integration_test_cc.j2 | 99 +++++++++++ .../filters/network/kafka/request_codec.cc | 37 ++-- .../filters/network/kafka/request_codec.h | 39 ++++- test/extensions/filters/network/kafka/BUILD | 32 +++- .../kafka/request_codec_integration_test.cc | 82 +++++++++ .../network/kafka/request_codec_test.cc | 99 ----------- .../network/kafka/request_codec_unit_test.cc | 165 ++++++++++++++++++ .../network/kafka/serialization_test.cc | 2 - .../network/kafka/serialization_utilities.h | 2 + 11 files changed, 459 insertions(+), 130 deletions(-) create mode 100644 source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 create mode 100644 test/extensions/filters/network/kafka/request_codec_integration_test.cc delete mode 100644 test/extensions/filters/network/kafka/request_codec_test.cc create mode 100644 test/extensions/filters/network/kafka/request_codec_unit_test.cc diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index eaa6991b35af8..87a1f88978be2 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -38,12 +38,15 @@ class AbstractRequest : public Message { public: AbstractRequest(const RequestHeader& request_header) : request_header_{request_header} {}; -protected: + /** + * Request's header + */ const RequestHeader request_header_; }; /** * Concrete request that carries data particular to given request type + * (can be considered a container) */ template class ConcreteRequest : public AbstractRequest { public: diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py 
b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index ae5396d6531b5..8509f4887646d 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -14,7 +14,7 @@ def main(): COMMAND : 'generate-source', to generate source files 'generate-test', to generate test files OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', - if generate-test: location of 'requests_test.cc' + if generate-test: location of 'requests_test.cc', 'request_codec_request_integration_test.cc' INPUT_FILES: Kafka protocol json files to be processed Kafka spec files are provided at https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message and in Kafka clients jar file @@ -23,11 +23,13 @@ def main(): - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from requests.h When generating test code, it creates: - requests_test.cc - serialization/deserialization tests for kafka structures + - request_codec_request_integration_test.cc - integration test for all request operations using the codec API Templates used are: - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2 - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2 - to create 'requests_test.cc': requests_test_cc.j2 + - to create 'request_codec_request_integration_test.cc' - request_codec_request_integration_test_cc.j2 """ import sys @@ -40,7 +42,8 @@ def main(): input_files = sys.argv[4:] elif 'generate-test' == command: requests_test_cc_file = os.path.abspath(sys.argv[2]) - input_files = sys.argv[3:] + request_codec_request_integration_test_cc_file = os.path.abspath(sys.argv[3]) + input_files = sys.argv[4:] else: raise ValueError('invalid command: ' + command) @@ -76,26 +79,32 @@ def main(): requests_h_contents += request_parsers_template.render(complex_type=request) # full file with headers, namespace declaration etc. 
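# For illustration, combining the docstring above with the BUILD genrules elsewhere in
# this patch, the expected invocations look like the following (file names are
# placeholders):
#
#   kafka_generator.py generate-source requests.h kafka_request_resolver.cc \
#       ProduceRequest.json FetchRequest.json ...
#   kafka_generator.py generate-test requests_test.cc \
#       request_codec_request_integration_test.cc ProduceRequest.json FetchRequest.json ...
#
# i.e. the fixed output locations come first (two per command), followed by any number of
# Kafka protocol JSON spec files.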
-    requests_header_template = RenderingHelper.get_template('requests_h.j2')
-    contents = requests_header_template.render(contents=requests_h_contents)
+    template = RenderingHelper.get_template('requests_h.j2')
+    contents = template.render(contents=requests_h_contents)
     with open(requests_h_file, 'w') as fd:
       fd.write(contents)

-    kafka_request_resolver_template = RenderingHelper.get_template('kafka_request_resolver_cc.j2')
-    contents = kafka_request_resolver_template.render(request_types=requests)
+    template = RenderingHelper.get_template('kafka_request_resolver_cc.j2')
+    contents = template.render(request_types=requests)
     with open(kafka_request_resolver_cc_file, 'w') as fd:
       fd.write(contents)

   # test code
   if 'generate-test' == command:
-    requests_test_template = RenderingHelper.get_template('requests_test_cc.j2')
-    contents = requests_test_template.render(request_types=requests)
+    template = RenderingHelper.get_template('requests_test_cc.j2')
+    contents = template.render(request_types=requests)
     with open(requests_test_cc_file, 'w') as fd:
       fd.write(contents)

+    template = RenderingHelper.get_template('request_codec_request_integration_test_cc.j2')
+    contents = template.render(request_types=requests)
+
+    with open(request_codec_request_integration_test_cc_file, 'w') as fd:
+      fd.write(contents)
+

 def parse_request(spec):
   """
@@ -185,7 +194,7 @@ def used_fields(self):
   def constructor_signature(self):
     """
     Return constructor signature
-    Mutliple versions of the same structure can have identical signatures (due to version bumps in Kafka)
+    Multiple versions of the same structure can have identical signatures (due to version bumps in Kafka)
     """
     parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields())
     return ', '.join(parameter_spec)
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2
new file mode 100644
index 0000000000000..0ac7f4a9c5418
--- /dev/null
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2
@@ -0,0 +1,99 @@
+{#
+  Template for 'request_codec_request_integration_test.cc'
+
+  Provides integration tests using the Kafka codec.
+  The only thing happening in these tests is creation of messages, passing them to the codec,
+  and verifying that the received parsed values are the same as the data sent.
+#}
+#include "extensions/filters/network/kafka/request_codec.h"
+#include "extensions/filters/network/kafka/requests.h"
+
+#include "test/mocks/server/mocks.h"
+
+#include "gtest/gtest.h"
+
+namespace Envoy {
+namespace Extensions {
+namespace NetworkFilters {
+namespace Kafka {
+
+class RequestCodecIntegrationTest : public testing::Test {
+protected:
+  template <typename T> void putInBuffer(T arg);
+
+  Buffer::OwnedImpl buffer_;
+};
+
+class CapturingRequestCallback : public RequestCallback {
+public:
+  virtual void onMessage(MessageSharedPtr request) override;
+
+  const std::vector<MessageSharedPtr>& getCaptured() const;
+
+private:
+  std::vector<MessageSharedPtr> captured_;
+};
+
+typedef std::shared_ptr<CapturingRequestCallback> CapturingRequestCallbackSharedPtr;
+
+void CapturingRequestCallback::onMessage(MessageSharedPtr message) {
+  captured_.push_back(message);
+}
+
+const std::vector<MessageSharedPtr>& CapturingRequestCallback::getCaptured() const {
+  return captured_;
+}
+
+{% for request_type in request_types %}
+
+// integration test for {{ request_type.name }} messages
+
+TEST_F(RequestCodecIntegrationTest, shouldHandle{{ request_type.name }}Messages) {
+  // given
+  using Request = ConcreteRequest<{{ request_type.name }}>;
+
+  std::vector<Request> sent;
+  int32_t correlation_id = 0;
+
+  {% for field_list in request_type.compute_field_lists() %}
+  for (int i = 0; i < 100; ++i) {
+    const RequestHeader header = { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation_id++, "client-id" };
+    const {{ request_type.name }} data = { {{ field_list.example_value() }} };
+    const Request request = {header, data};
+    putInBuffer(request);
+    sent.push_back(request);
+  }
+  {% endfor %}
+
+  const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance();
+  const RequestParserResolver& request_parser_resolver = RequestParserResolver::getDefaultInstance();
+  const CapturingRequestCallbackSharedPtr request_callback = std::make_shared<CapturingRequestCallback>();
+
+  RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}};
+
+  // when
+  testee.onData(buffer_);
+
+  // then
+  const std::vector<MessageSharedPtr>& received = request_callback->getCaptured();
+  ASSERT_EQ(received.size(), sent.size());
+
+  for (size_t i = 0; i < received.size(); ++i) {
+    const std::shared_ptr<Request> request = std::dynamic_pointer_cast<Request>(received[i]);
+    ASSERT_NE(request, nullptr);
+    ASSERT_EQ(*request, sent[i]);
+  }
+}
+{% endfor %}
+
+// misc utilities
+template <typename T>
+void RequestCodecIntegrationTest::putInBuffer(const T arg) {
+  MessageEncoderImpl serializer{buffer_};
+  serializer.encode(arg);
+}
+
+} // namespace Kafka
+} // namespace NetworkFilters
+} // namespace Extensions
+} // namespace Envoy
diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc
index 2feaf4fb226b1..cfc577bc3fb2e 100644
--- a/source/extensions/filters/network/kafka/request_codec.cc
+++ b/source/extensions/filters/network/kafka/request_codec.cc
@@ -10,13 +10,24 @@ namespace Extensions {
 namespace NetworkFilters {
 namespace Kafka {
 
+class RequestStartParserFactory : public InitialParserFactory {
+
+  ParserSharedPtr create(const RequestParserResolver& parser_resolver) const override {
+    return std::make_shared<RequestStartParser>(parser_resolver);
+  }
+};
+
+const InitialParserFactory& InitialParserFactory::getDefaultInstance() {
+  CONSTRUCT_ON_FIRST_USE(RequestStartParserFactory);
+}
+
 // convert buffer to slices and pass them to `doParse`
 void RequestDecoder::onData(Buffer::Instance& data) {
   uint64_t num_slices = data.getRawSlices(nullptr, 0);
   STACK_ARRAY(slices, Buffer::RawSlice, num_slices);
   data.getRawSlices(slices.begin(), num_slices);
   for (const Buffer::RawSlice& slice : slices) {
-    doParse(current_parser_, slice);
+    doParse(slice);
   }
 }
 
@@ -25,18 +36,20 @@ void RequestDecoder::onData(Buffer::Instance& data) {
  * - forward data to current parser
  * - receive parser response:
  * -- if still waiting, do nothing
- * -- if next parser, replace parser, and keep feeding, if still have data
+ * -- if next parser, replace the current parser and keep feeding it while data remains
  * -- if parser message:
  * --- notify callbacks
- * --- replace parser with new start parser, as we are going to parse another request
+ * --- replace the current parser with a new start parser, as we are going to parse another request
 */
-void RequestDecoder::doParse(ParserSharedPtr& parser, const Buffer::RawSlice& slice) {
+void RequestDecoder::doParse(const Buffer::RawSlice& slice) {
   const char* bytes = reinterpret_cast<const char*>(slice.mem_);
   absl::string_view data = {bytes, slice.len_};
   while (!data.empty()) {
-    ParseResponse result = parser->parse(data);
-    // this loop guarantees that parsers consuming 0 bytes also get processed
+
+    // feed the data to the parser
+    ParseResponse result = current_parser_->parse(data);
+    // this loop guarantees that parsers consuming 0 bytes also get processed in this invocation
     while (result.hasData()) {
       if (!result.next_parser_) {
@@ -46,12 +59,16 @@
         callback->onMessage(result.message_);
       }
 
-      // we finished parsing this request, start anew
-      parser = std::make_shared<RequestStartParser>(parser_resolver_);
+      // as we finished parsing this request, re-initialize the parser
+      current_parser_ = factory_.create(parser_resolver_);
     } else {
+
+      // the next parser that's supposed to consume the rest of the payload was given
+      current_parser_ = result.next_parser_;
     }
-    result = parser->parse(data);
+
+    // keep parsing the data
+    result = current_parser_->parse(data);
   }
 }
}
diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h
index c856b9461fcdb..0150355cc7f6b 100644
--- a/source/extensions/filters/network/kafka/request_codec.h
+++ b/source/extensions/filters/network/kafka/request_codec.h
@@ -29,14 +29,32 @@ class RequestCallback {
 
 typedef std::shared_ptr<RequestCallback> RequestCallbackSharedPtr;
 
+/**
+ * Provides the initial parser for messages
+ * (class extracted to allow injecting test factories)
+ */
+class InitialParserFactory {
+public:
+  virtual ~InitialParserFactory() = default;
+
+  /**
+   * Returns the default instance, which creates RequestStartParser instances
+   */
+  static const InitialParserFactory& getDefaultInstance();
+
+  /**
+   * Creates a parser using the given parser resolver
+   */
+  virtual ParserSharedPtr create(const RequestParserResolver& parser_resolver) const PURE;
+};
+
 /**
  * Decoder that decodes Kafka requests
  * When a request is decoded, the callbacks are notified, in order
 *
 * This decoder uses chain of parsers to parse fragments of a request
 * Each parser along the line returns the fully parsed message or the next parser
- * Stores parse state (have `onData` invoked multiple times for messages that are larger than single
- * buffer)
+ * Stores parse state (as a large message's payload can be provided through multiple `onData` calls)
 */
 class RequestDecoder : public MessageDecoder {
 public:
@@ -48,8 +66,16 @@ class RequestDecoder : public MessageDecoder {
  */
   RequestDecoder(const RequestParserResolver& parserResolver,
                  const std::vector<RequestCallbackSharedPtr> callbacks)
-      : parser_resolver_{parserResolver}, callbacks_{callbacks},
-        current_parser_{new RequestStartParser(parser_resolver_)} {};
+      : RequestDecoder(InitialParserFactory::getDefaultInstance(), parserResolver, callbacks){};
+
+  /**
+   * Visible for testing.
+   * Allows injecting the initial parser factory.
+   */
+  RequestDecoder(const InitialParserFactory& factory, const RequestParserResolver& parserResolver,
+                 const std::vector<RequestCallbackSharedPtr> callbacks)
+      : factory_{factory}, parser_resolver_{parserResolver}, callbacks_{callbacks},
+        current_parser_{factory_.create(parser_resolver_)} {};
 
  /**
   * Consumes all data present in a buffer
@@ -60,9 +86,12 @@
   void onData(Buffer::Instance& data) override;
 
 private:
-  void doParse(ParserSharedPtr& parser, const Buffer::RawSlice& slice);
+  void doParse(const Buffer::RawSlice& slice);
+
+  const InitialParserFactory& factory_;
 
   const RequestParserResolver& parser_resolver_;
+
   const std::vector<RequestCallbackSharedPtr> callbacks_;
 
   ParserSharedPtr current_parser_;
 
diff --git 
a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index 6b9af10d0b776..57b45ad21267b 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -16,6 +16,7 @@ envoy_cc_test_library( name = "serialization_utilities_lib", hdrs = ["serialization_utilities.h"], deps = [ + "//source/common/buffer:buffer_lib", "//source/extensions/filters/network/kafka:serialization_lib", ], ) @@ -67,8 +68,28 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "request_codec_test", - srcs = ["request_codec_test.cc"], + name = "request_codec_unit_test", + srcs = ["request_codec_unit_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "request_codec_integration_test", + srcs = ["request_codec_integration_test.cc"], + extension_name = "envoy.filters.network.kafka", + deps = [ + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", + "//test/mocks/server:server_mocks", + ], +) + +envoy_extension_cc_test( + name = "request_codec_request_integration_test", + srcs = ["request_codec_request_integration_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ "//source/extensions/filters/network/kafka:kafka_request_codec_lib", @@ -92,10 +113,13 @@ genrule( "@kafka_produce_request_spec//file", "@kafka_fetch_request_spec//file", ], - outs = ["requests_test.cc"], + outs = [ + "requests_test.cc", + "request_codec_request_integration_test.cc", + ], cmd = """ ./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test \ - $(location requests_test.cc) \ + $(location requests_test.cc) $(location request_codec_request_integration_test.cc) \ $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file) """, tools = [ diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc new file mode 100644 index 0000000000000..7eb51fb2dfe48 --- /dev/null +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -0,0 +1,82 @@ +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/mocks/server/mocks.h" + +#include "gtest/gtest.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +class RequestDecoderTest : public testing::Test { +protected: + template void putInBuffer(T arg); + + Buffer::OwnedImpl buffer_; +}; + +class CapturingRequestCallback : public RequestCallback { +public: + virtual void onMessage(MessageSharedPtr request) override; + + const std::vector& getCaptured() const; + +private: + std::vector captured_; +}; + +typedef std::shared_ptr CapturingRequestCallbackSharedPtr; + +void CapturingRequestCallback::onMessage(MessageSharedPtr message) { captured_.push_back(message); } + +const std::vector& CapturingRequestCallback::getCaptured() const { + return captured_; +} + +TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { + // given + // api keys have values below 100, so the messages generated in this loop should not be recognized + const int16_t base_api_key = 0; + std::vector sent_headers; + for (int16_t i = 0; i < 1000; ++i) { + const int16_t api_key = static_cast(base_api_key + i); + const RequestHeader header = {api_key, 0, 0, "client-id"}; + const std::vector data = 
std::vector(1024); + putInBuffer(ConcreteRequest>{header, data}); + sent_headers.push_back(header); + } + + const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance(); + const RequestParserResolver& request_parser_resolver = + RequestParserResolver::getDefaultInstance(); + const CapturingRequestCallbackSharedPtr request_callback = + std::make_shared(); + + RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}}; + + // when + testee.onData(buffer_); + + // then + const std::vector& received = request_callback->getCaptured(); + ASSERT_EQ(received.size(), sent_headers.size()); + + for (size_t i = 0; i < received.size(); ++i) { + const std::shared_ptr request = + std::dynamic_pointer_cast(received[i]); + ASSERT_NE(request, nullptr); + ASSERT_EQ(request->request_header_, sent_headers[i]); + } +} + +// misc utilities +template void RequestDecoderTest::putInBuffer(T arg) { + MessageEncoderImpl serializer{buffer_}; + serializer.encode(arg); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/request_codec_test.cc b/test/extensions/filters/network/kafka/request_codec_test.cc deleted file mode 100644 index 4cab50a825d12..0000000000000 --- a/test/extensions/filters/network/kafka/request_codec_test.cc +++ /dev/null @@ -1,99 +0,0 @@ -#include "extensions/filters/network/kafka/request_codec.h" - -#include "test/mocks/server/mocks.h" - -#include "gmock/gmock.h" -#include "gtest/gtest.h" - -using testing::_; - -namespace Envoy { -namespace Extensions { -namespace NetworkFilters { -namespace Kafka { - -class RequestDecoderTest : public testing::Test { -public: - Buffer::OwnedImpl buffer_; - - template std::shared_ptr serializeAndDeserialize(T request); -}; - -class MockMessageListener : public RequestCallback { -public: - MOCK_METHOD1(onMessage, void(MessageSharedPtr)); -}; - -class MockRequestParserResolver : public RequestParserResolver { -public: - MockRequestParserResolver() : RequestParserResolver({}){}; - MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); -}; - -template std::shared_ptr RequestDecoderTest::serializeAndDeserialize(T request) { - MessageEncoderImpl serializer{buffer_}; - serializer.encode(request); - - std::shared_ptr mock_listener = std::make_shared(); - RequestDecoder testee{RequestParserResolver::getDefaultInstance(), {mock_listener}}; - - MessageSharedPtr receivedMessage; - EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&receivedMessage)); - - testee.onData(buffer_); - - return std::dynamic_pointer_cast(receivedMessage); -}; - -ParserSharedPtr createSentinelParser(testing::Unused, testing::Unused, - RequestContextSharedPtr context) { - return std::make_shared(context); -} - -struct MockRequest { - const int32_t field1_ = 1; - const int64_t field2_ = 2; - - size_t encode(Buffer::Instance& buffer, EncodingContext& encoder) const { - size_t written{0}; - written += encoder.encode(field1_, buffer); - written += encoder.encode(field2_, buffer); - return written; - } - - friend std::ostream& operator<<(std::ostream& os, const MockRequest&) { - return os << "{MockRequest}"; - }; -}; - -TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { - // given - MessageEncoderImpl serializer{buffer_}; - MockRequest data{}; - // api key & version values do not matter, as resolver recognizes nothing - ConcreteRequest request = {{1000, 2000, 
3000, "correlation-id"}, data}; - - serializer.encode(request); - - MockRequestParserResolver mock_parser_resolver{}; - EXPECT_CALL(mock_parser_resolver, createParser(_, _, _)) - .WillOnce(testing::Invoke(createSentinelParser)); - std::shared_ptr mock_listener = std::make_shared(); - RequestDecoder testee{mock_parser_resolver, {mock_listener}}; - - MessageSharedPtr rev; - EXPECT_CALL(*mock_listener, onMessage(_)).WillOnce(testing::SaveArg<0>(&rev)); - - // when - testee.onData(buffer_); - - // then - ASSERT_NE(rev, nullptr); - auto received = std::dynamic_pointer_cast(rev); - ASSERT_NE(received, nullptr); -} - -} // namespace Kafka -} // namespace NetworkFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc new file mode 100644 index 0000000000000..0feb02fec4932 --- /dev/null +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -0,0 +1,165 @@ +#include "extensions/filters/network/kafka/request_codec.h" + +#include "test/mocks/server/mocks.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::AnyNumber; +using testing::Eq; +using testing::Invoke; +using testing::ResultOf; +using testing::Return; + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +class MockParserFactory : public InitialParserFactory { +public: + MOCK_CONST_METHOD1(create, ParserSharedPtr(const RequestParserResolver&)); +}; + +class MockParser : public Parser { +public: + MOCK_METHOD1(parse, ParseResponse(absl::string_view&)); +}; + +typedef std::shared_ptr MockParserSharedPtr; + +class MockRequestParserResolver : public RequestParserResolver { +public: + MockRequestParserResolver() : RequestParserResolver({}){}; + MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); +}; + +class MockRequestCallback : public RequestCallback { +public: + MOCK_METHOD1(onMessage, void(MessageSharedPtr)); +}; + +typedef std::shared_ptr MockRequestCallbackSharedPtr; + +class RequestDecoderTest : public testing::Test { +protected: + template void putInBuffer(T arg); + + Buffer::OwnedImpl buffer_; + + MockParserFactory initial_parser_factory_{}; + MockRequestParserResolver parser_resolver_{}; + MockRequestCallbackSharedPtr request_callback_{std::make_shared()}; +}; + +ParseResponse consumeOneByte(absl::string_view& data) { + data = {data.data() + 1, data.size() - 1}; + return ParseResponse::stillWaiting(); +} + +TEST_F(RequestDecoderTest, shouldDoNothingIfParserNeverReturnsMessage) { + // given + putInBuffer(ConcreteRequest{{}, 0}); + + MockParserSharedPtr parser = std::make_shared(); + EXPECT_CALL(*parser, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)).WillOnce(Return(parser)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then - request_callback is not interacted with +} + +TEST_F(RequestDecoderTest, shouldUseNewParserAsResponse) { + // given + putInBuffer(ConcreteRequest{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + MockParserSharedPtr parser2 = std::make_shared(); + MockParserSharedPtr parser3 = std::make_shared(); + EXPECT_CALL(*parser1, parse(_)).WillOnce(Return(ParseResponse::nextParser(parser2))); + EXPECT_CALL(*parser2, 
parse(_)).WillOnce(Return(ParseResponse::nextParser(parser3))); + EXPECT_CALL(*parser3, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)).WillOnce(Return(parser1)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then - request_callback is not interacted with +} + +TEST_F(RequestDecoderTest, shouldReturnParsedMessageAndReinitialize) { + // given + putInBuffer(ConcreteRequest{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + MessageSharedPtr message = std::make_shared(RequestHeader{}); + EXPECT_CALL(*parser1, parse(_)).WillOnce(Return(ParseResponse::parsedMessage(message))); + + MockParserSharedPtr parser2 = std::make_shared(); + EXPECT_CALL(*parser2, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)) + .WillOnce(Return(parser1)) + .WillOnce(Return(parser2)); + + EXPECT_CALL(*request_callback_, onMessage(message)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then - request_callback got notified only once +} + +TEST_F(RequestDecoderTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { + // given + putInBuffer(ConcreteRequest{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + MockParserSharedPtr parser2 = std::make_shared(); + MockParserSharedPtr parser3 = std::make_shared(); + + auto consume_and_return = [this, &parser2](absl::string_view& data) -> ParseResponse { + data = {data.data() + buffer_.length(), data.size() - buffer_.length()}; + return ParseResponse::nextParser(parser2); + }; + EXPECT_CALL(*parser1, parse(_)).WillOnce(Invoke(consume_and_return)); + MessageSharedPtr message = std::make_shared(RequestHeader{}); + EXPECT_CALL(*parser2, parse(_)).WillOnce(Return(ParseResponse::parsedMessage(message))); + EXPECT_CALL(*parser3, parse(ResultOf([](absl::string_view arg) { return arg.size(); }, Eq(0)))) + .WillOnce(Return(ParseResponse::stillWaiting())); + + EXPECT_CALL(initial_parser_factory_, create(_)) + .WillOnce(Return(parser1)) + .WillOnce(Return(parser3)); + + EXPECT_CALL(*request_callback_, onMessage(message)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then - parser3 was given only empty data (size 0) +} + +// misc utilities +template void RequestDecoderTest::putInBuffer(T arg) { + MessageEncoderImpl serializer{buffer_}; + serializer.encode(arg); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 20cd1a4cb71fb..6eea87453ff2f 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -1,5 +1,3 @@ -#include "extensions/filters/network/kafka/serialization.h" - #include "test/extensions/filters/network/kafka/serialization_utilities.h" namespace Envoy { diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h index f7674c1548530..8135d79a48d56 100644 --- a/test/extensions/filters/network/kafka/serialization_utilities.h +++ 
b/test/extensions/filters/network/kafka/serialization_utilities.h @@ -3,6 +3,8 @@ #include "common/buffer/buffer_impl.h" #include "common/common/stack_array.h" +#include "extensions/filters/network/kafka/serialization.h" + #include "absl/strings/string_view.h" #include "gtest/gtest.h" From 28641b4849a2fb8681d5cd68ca1974ca3f1cd8f6 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 19 Mar 2019 13:49:57 -0700 Subject: [PATCH 17/29] Download whole Kafka specification; test fixes Signed-off-by: Adam Kotwasinski --- api/bazel/repositories.bzl | 26 ++++++++++++------- api/bazel/repository_locations.bzl | 7 +++++ source/extensions/filters/network/kafka/BUILD | 5 ++-- test/extensions/filters/network/kafka/BUILD | 5 ++-- .../kafka/request_codec_integration_test.cc | 3 ++- 5 files changed, 29 insertions(+), 17 deletions(-) diff --git a/api/bazel/repositories.bzl b/api/bazel/repositories.bzl index 3b938b86e4ca0..62cd26e4f445d 100644 --- a/api/bazel/repositories.bzl +++ b/api/bazel/repositories.bzl @@ -1,4 +1,3 @@ -load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_file") load("@bazel_tools//tools/build_defs/repo:http.bzl", "http_archive") load(":envoy_http_archive.bzl", "envoy_http_archive") load(":repository_locations.bzl", "REPOSITORY_LOCATIONS") @@ -32,15 +31,10 @@ def api_dependencies(): locations = REPOSITORY_LOCATIONS, build_file_content = OPENCENSUSTRACE_BUILD_CONTENT, ) - http_file( - name = "kafka_produce_request_spec", - sha256 = "e035f70a136ef5a5ef2ff17b52dc10f2eae4ac596639689f5584054909d5816f", - urls = ["https://raw.githubusercontent.com/apache/kafka/2.2/clients/src/main/resources/common/message/ProduceRequest.json"], - ) - http_file( - name = "kafka_fetch_request_spec", - sha256 = "9209b68fe0295818071c2f644363cf71e6443eb61f8e9d2636412876c5e2bae8", - urls = ["https://raw.githubusercontent.com/apache/kafka/2.2/clients/src/main/resources/common/message/FetchRequest.json"], + envoy_http_archive( + name = "kafka_source", + locations = REPOSITORY_LOCATIONS, + build_file_content = KAFKASOURCE_BUILD_CONTENT, ) GOOGLEAPIS_BUILD_CONTENT = """ @@ -296,3 +290,15 @@ go_proto_library( visibility = ["//visibility:public"], ) """ + +KAFKASOURCE_BUILD_CONTENT = """ + +filegroup( + name = "request_protocol_files", + srcs = glob([ + "*Request.json", + ]), + visibility = ["//visibility:public"], +) + +""" diff --git a/api/bazel/repository_locations.bzl b/api/bazel/repository_locations.bzl index 6d68524399fad..a2f08b621a75b 100644 --- a/api/bazel/repository_locations.bzl +++ b/api/bazel/repository_locations.bzl @@ -16,6 +16,8 @@ GOOGLEAPIS_SHA = "16f5b2e8bf1e747a32f9a62e211f8f33c94645492e9bbd72458061d9a9de1f PROMETHEUS_GIT_SHA = "99fa1f4be8e564e8a6b613da7fa6f46c9edafc6c" # Nov 17, 2017 PROMETHEUS_SHA = "783bdaf8ee0464b35ec0c8704871e1e72afa0005c3f3587f65d9d6694bf3911b" +KAFKA_SOURCE_SHA = "ae7a1696c0a0302b43c5b21e515c37e6ecd365941f68a510a7e442eebddf39a1" # 2.2.0-rc2 + REPOSITORY_LOCATIONS = dict( bazel_skylib = dict( sha256 = BAZEL_SKYLIB_SHA256, @@ -48,4 +50,9 @@ REPOSITORY_LOCATIONS = dict( strip_prefix = "opencensus-proto-" + OPENCENSUS_RELEASE + "/src/opencensus/proto/trace/v1", urls = ["https://github.com/census-instrumentation/opencensus-proto/archive/v" + OPENCENSUS_RELEASE + ".tar.gz"], ), + kafka_source = dict( + sha256 = KAFKA_SOURCE_SHA, + strip_prefix = "kafka-2.2.0-rc2/clients/src/main/resources/common/message", + urls = ["https://github.com/apache/kafka/archive/2.2.0-rc2.zip"], + ), ) diff --git a/source/extensions/filters/network/kafka/BUILD 
b/source/extensions/filters/network/kafka/BUILD index c3a63cd0ccf03..f4cc3ce6bb8b4 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -47,8 +47,7 @@ envoy_cc_library( genrule( name = "kafka_generated_source", srcs = [ - "@kafka_produce_request_spec//file", - "@kafka_fetch_request_spec//file", + "@kafka_source//:request_protocol_files", ], outs = [ "requests.h", @@ -57,7 +56,7 @@ genrule( cmd = """ ./$(location :kafka_code_generator) generate-source \ $(location requests.h) $(location kafka_request_resolver.cc) \ - $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file) + $(SRCS) """, tools = [ ":kafka_code_generator", diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index 57b45ad21267b..07a57856a052a 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -110,8 +110,7 @@ envoy_extension_cc_test( genrule( name = "requests_test_generator", srcs = [ - "@kafka_produce_request_spec//file", - "@kafka_fetch_request_spec//file", + "@kafka_source//:request_protocol_files", ], outs = [ "requests_test.cc", @@ -120,7 +119,7 @@ genrule( cmd = """ ./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test \ $(location requests_test.cc) $(location request_codec_request_integration_test.cc) \ - $(location @kafka_produce_request_spec//file) $(location @kafka_fetch_request_spec//file) + $(SRCS) """, tools = [ "//source/extensions/filters/network/kafka:kafka_code_generator", diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc index 7eb51fb2dfe48..fdd56b0f75199 100644 --- a/test/extensions/filters/network/kafka/request_codec_integration_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -34,10 +34,11 @@ const std::vector& CapturingRequestCallback::getCaptured() con return captured_; } +// other request types are tested in (generated) 'request_codec_request_integration_test.cc' TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given // api keys have values below 100, so the messages generated in this loop should not be recognized - const int16_t base_api_key = 0; + const int16_t base_api_key = 100; std::vector sent_headers; for (int16_t i = 0; i < 1000; ++i) { const int16_t api_key = static_cast(base_api_key + i); From 3b9e324e97873bffe440d0a6a7fa86fcd41bd4d4 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Fri, 22 Mar 2019 11:37:35 -0700 Subject: [PATCH 18/29] Activate kafka tests in builds; review fixes: documentation, formatting Signed-off-by: Adam Kotwasinski --- source/extensions/extensions_build_config.bzl | 1 + .../extensions/filters/network/kafka/codec.h | 8 +- .../filters/network/kafka/kafka_request.h | 30 ++-- .../network/kafka/kafka_request_parser.cc | 4 +- .../network/kafka/kafka_request_parser.h | 51 +++--- .../filters/network/kafka/kafka_types.h | 8 +- .../filters/network/kafka/message.h | 4 +- .../extensions/filters/network/kafka/parser.h | 20 +-- .../complex_type_template.j2 | 90 +++++----- .../kafka_generator.py | 123 +++++++------- .../kafka_request_resolver_cc.j2 | 9 +- ...quest_codec_request_integration_test_cc.j2 | 104 ++++++------ .../protocol_code_generator/request_parser.j2 | 14 +- .../protocol_code_generator/requests_h.j2 | 37 +++-- .../requests_test_cc.j2 | 17 +- 
.../filters/network/kafka/request_codec.cc | 38 ++--- .../filters/network/kafka/request_codec.h | 4 +- .../filters/network/kafka/serialization.h | 156 +++++++++--------- .../serialization_composite_generator.py | 28 ++-- .../serialization_composite_h.j2 | 77 +++++---- .../serialization_composite_test_cc.j2 | 80 ++++----- test/extensions/filters/network/kafka/BUILD | 4 +- .../kafka/kafka_request_parser_test.cc | 17 +- .../kafka/request_codec_integration_test.cc | 7 +- .../network/kafka/request_codec_unit_test.cc | 15 +- .../network/kafka/serialization_test.cc | 20 +-- .../network/kafka/serialization_utilities.h | 28 ++-- 27 files changed, 503 insertions(+), 491 deletions(-) diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index bd3793cc9d16a..bc6d0667569ea 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -69,6 +69,7 @@ EXTENSIONS = { "envoy.filters.network.echo": "//source/extensions/filters/network/echo:config", "envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config", "envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config", + "envoy.filters.network.kafka": "//source/extensions/filters/network/kafka:config", "envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config", "envoy.filters.network.mysql_proxy": "//source/extensions/filters/network/mysql_proxy:config", "envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config", diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h index 8aabd4d620d50..01b9a3c84ad15 100644 --- a/source/extensions/filters/network/kafka/codec.h +++ b/source/extensions/filters/network/kafka/codec.h @@ -11,28 +11,28 @@ namespace NetworkFilters { namespace Kafka { /** - * Kafka message decoder + * Kafka message decoder. */ class MessageDecoder { public: virtual ~MessageDecoder() = default; /** - * Processes given buffer attempting to decode messages contained within + * Processes given buffer attempting to decode messages contained within. * @param data buffer instance */ virtual void onData(Buffer::Instance& data) PURE; }; /** - * Kafka message encoder + * Kafka message encoder. */ class MessageEncoder { public: virtual ~MessageEncoder() = default; /** - * Encodes given message + * Encodes given message. * @param message message to be encoded */ virtual void encode(const Message& message) PURE; diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 87a1f88978be2..759d0ddd6a7e7 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -1,7 +1,5 @@ #pragma once -#include - #include "envoy/common/exception.h" #include "extensions/filters/network/kafka/message.h" @@ -14,7 +12,7 @@ namespace NetworkFilters { namespace Kafka { /** - * Represents fields that are present in every Kafka request message + * Represents fields that are present in every Kafka request message. * @see http://kafka.apache.org/protocol.html#protocol_messages */ struct RequestHeader { @@ -30,8 +28,8 @@ struct RequestHeader { }; /** - * Abstract Kafka request - * Contains data present in every request (the header with request key, version, etc.) + * Abstract Kafka request. 
+ * Contains data present in every request (the header with request key, version, etc.).
 * @see http://kafka.apache.org/protocol.html#protocol_messages
 */
class AbstractRequest : public Message {
@@ -39,35 +37,34 @@ class AbstractRequest : public Message {
  AbstractRequest(const RequestHeader& request_header) : request_header_{request_header} {};

  /**
-   * Request's header
+   * Request's header.
   */
  const RequestHeader request_header_;
};

/**
- * Concrete request that carries data particular to given request type
- * (can be considered a container)
+ * Concrete request that carries data particular to given request type.
 */
template <typename RequestData> class ConcreteRequest : public AbstractRequest {
public:
  /**
-   * Request header fields need to be initialized by user in case of newly created requests
+   * Request header fields need to be initialized by user in case of newly created requests.
   */
  ConcreteRequest(const RequestHeader& request_header, const RequestData& data)
      : AbstractRequest{request_header}, data_{data} {};

  /**
-   * Encodes given request into a buffer, with any extra configuration carried by the context
+   * Encodes given request into a buffer, with any extra configuration carried by the context.
   */
  size_t encode(Buffer::Instance& dst) const override {
    EncodingContext context{request_header_.api_version_};
    size_t written{0};
-    // encode request header
+    // Encode request header.
    written += context.encode(request_header_.api_key_, dst);
    written += context.encode(request_header_.api_version_, dst);
    written += context.encode(request_header_.correlation_id_, dst);
    written += context.encode(request_header_.client_id_, dst);
-    // encode request-specific data
+    // Encode request-specific data.
    written += context.encode(data_, dst);
    return written;
  }
@@ -82,15 +79,16 @@ template <typename RequestData> class ConcreteRequest : public AbstractRequest {
/**
 * Request that did not have api_key & api_version that could be matched with any of
- * request-specific parsers
+ * request-specific parsers.
+ * Right now it acts as a placeholder only, and does not carry the request data.
 */
class UnknownRequest : public AbstractRequest {
public:
  UnknownRequest(const RequestHeader& request_header) : AbstractRequest{request_header} {};

-  // this isn't the prettiest, as we have thrown away the data
-  // TODO(adamkotwasinski) discuss capturing the data as-is, and simply putting it back
-  // this would add ability to forward unknown types of requests in cluster-proxy
+  /**
+   * It is impossible to encode an unknown request, as it is only a placeholder.
+   */
  size_t encode(Buffer::Instance&) const override {
    throw EnvoyException("cannot serialize unknown request");
  }
diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc
index d797b8dc07ca3..db45d33a05d83 100644
--- a/source/extensions/filters/network/kafka/kafka_request_parser.cc
+++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc
@@ -25,8 +25,8 @@ ParseResponse RequestHeaderParser::parse(absl::string_view& data) {
  try {
    context_->remaining_request_size_ -= deserializer_->feed(data);
  } catch (const EnvoyException& e) {
-    // unable to compute request header, but we still need to consume rest of request (some of the
-    // data might have been consumed)
+    // We were unable to compute the request header, but we still need to consume the rest of the
+    // request (some of the data might have been consumed during this attempt).
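+    // The difference between the original view and what remains is exactly the number of bytes
+    // the failed feed() call managed to take, so it still counts against this request's length.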
const int32_t consumed = static_cast(orig_data.size() - data.size()); context_->remaining_request_size_ -= consumed; context_->request_header_ = {-1, -1, -1, absl::nullopt}; diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h index cc29f8a655d4e..404a37b2a6fff 100644 --- a/source/extensions/filters/network/kafka/kafka_request_parser.h +++ b/source/extensions/filters/network/kafka/kafka_request_parser.h @@ -15,7 +15,7 @@ namespace NetworkFilters { namespace Kafka { /** - * Context that is shared between parsers that are handling the same single message + * Context that is shared between parsers that are handling the same single message. */ struct RequestContext { int32_t remaining_request_size_{0}; @@ -25,16 +25,16 @@ struct RequestContext { typedef std::shared_ptr RequestContextSharedPtr; /** - * Configuration object - * Resolves the parser that will be responsible for consuming the request-specific data - * In other words: provides (api_key, api_version) -> Parser function + * Request decoder configuration object. + * Resolves the parser that will be responsible for consuming the request-specific data. + * In other words: provides (api_key, api_version) -> Parser function. */ class RequestParserResolver { public: virtual ~RequestParserResolver() = default; /** - * Creates a parser that is going to process data specific for given api_key & api_version + * Creates a parser that is going to process data specific for given api_key & api_version. * @param api_key request type * @param api_version request version * @param context context to be used by parser @@ -44,13 +44,13 @@ class RequestParserResolver { RequestContextSharedPtr context) const; /** - * Return default resolver, that uses request's api key and version to provide a matching parser + * Return default resolver, that uses request's api key and version to provide a matching parser. */ static const RequestParserResolver& getDefaultInstance(); }; /** - * Request parser responsible for consuming request length and setting up context with this data + * Request parser responsible for consuming request length and setting up context with this data. * @see http://kafka.apache.org/protocol.html#protocol_common */ class RequestStartParser : public Parser { @@ -59,7 +59,7 @@ class RequestStartParser : public Parser { : parser_resolver_{parser_resolver}, context_{std::make_shared()} {}; /** - * Consumes 4 bytes (INT32) as request length and updates the context with that value + * Consumes 4 bytes (INT32) as request length and updates the context with that value. * @return RequestHeaderParser instance to process request header */ ParseResponse parse(absl::string_view& data) override; @@ -73,8 +73,8 @@ class RequestStartParser : public Parser { }; /** - * Deserializer that extracts request header (4 fields) - * Can throw, as one of the fields (client-id) can throw (nullable string with invalid length) + * Deserializer that extracts request header (4 fields). + * Can throw, as one of the fields (client-id) can throw (nullable string with invalid length). * @see http://kafka.apache.org/protocol.html#protocol_messages */ class RequestHeaderDeserializer @@ -85,25 +85,26 @@ class RequestHeaderDeserializer typedef std::unique_ptr RequestHeaderDeserializerPtr; /** - * Parser responsible for computing request header and updating the context with data resolved - * On a successful parse uses resolved data (api_key & api_version) to determine next parser. 
+ * Parser responsible for extracting the request header and putting it into context. + * On a successful parse the resolved data (api_key & api_version) is used to determine the next + * parser. * @see http://kafka.apache.org/protocol.html#protocol_messages */ class RequestHeaderParser : public Parser { public: - // default constructor + // Default constructor. RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context) : RequestHeaderParser{parser_resolver, context, std::make_unique()} {}; - // visible for testing + // Constructor visible for testing (allows for initial parser injection). RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context, RequestHeaderDeserializerPtr deserializer) : parser_resolver_{parser_resolver}, context_{context}, deserializer_{ std::move(deserializer)} {}; /** - * Uses data provided to compute request header + * Uses data provided to compute request header. * @return Parser instance responsible for processing rest of the message */ ParseResponse parse(absl::string_view& data) override; @@ -119,14 +120,14 @@ class RequestHeaderParser : public Parser { /** * Sentinel parser that is responsible for consuming message bytes for messages that had unsupported * api_key & api_version. It does not attempt to capture any data, just throws it away until end of - * message + * message. */ class SentinelParser : public Parser { public: SentinelParser(RequestContextSharedPtr context) : context_{context} {}; /** - * Returns UnknownRequest + * Returns UnknownRequest. Ignores (jumps over) the data provided. */ ParseResponse parse(absl::string_view& data) override; @@ -137,9 +138,9 @@ class SentinelParser : public Parser { }; /** - * Request parser uses a single deserializer to construct a request object + * Request parser uses a single deserializer to construct a request object. * This parser is responsible for consuming request-specific data (e.g. topic names) and always - * returns a parsed message + * returns a parsed message. * @param RequestType request class * @param DeserializerType deserializer type corresponding to request class (should be subclass of * Deserializer) @@ -147,27 +148,27 @@ class SentinelParser : public Parser { template class RequestParser : public Parser { public: /** - * Create a parser with given context + * Create a parser with given context. * @param context parse context containing request header */ RequestParser(RequestContextSharedPtr context) : context_{context} {}; /** - * Consume enough data to fill in deserializer and receive the parsed request - * Fill in request's header with data stored in context + * Consume enough data to fill in deserializer and receive the parsed request. + * Fill in request's header with data stored in context. */ ParseResponse parse(absl::string_view& data) override { context_->remaining_request_size_ -= deserializer.feed(data); if (deserializer.ready()) { if (0 == context_->remaining_request_size_) { - // after a successful parse, there should be nothing left - we have consumed all the bytes + // After a successful parse, there should be nothing left - we have consumed all the bytes. 
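        // Combine the header captured earlier by the header parser (kept in the shared context)
        // with the request-specific data that this parser has just deserialized.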
MessageSharedPtr msg = std::make_shared<ConcreteRequest<RequestType>>(
            context_->request_header_, deserializer.get());
        return ParseResponse::parsedMessage(msg);
      } else {
-        // the message makes no sense, the deserializer that matches the schema consumed all
-        // necessary data, but there's still unconsumed bytes
+        // The message makes no sense, the deserializer that matches the schema consumed all
+        // necessary data, but there are still bytes in this message.
        return ParseResponse::nextParser(std::make_shared<SentinelParser>(context_));
      }
    } else {
diff --git a/source/extensions/filters/network/kafka/kafka_types.h b/source/extensions/filters/network/kafka/kafka_types.h
index 1aa32106cb390..71d1ce920a82d 100644
--- a/source/extensions/filters/network/kafka/kafka_types.h
+++ b/source/extensions/filters/network/kafka/kafka_types.h
@@ -12,22 +12,22 @@ namespace Kafka {

/**
- * Nullable string used by Kafka
+ * Nullable string used by Kafka.
 */
typedef absl::optional<std::string> NullableString;

/**
- * Bytes array used by Kafka
+ * Bytes array used by Kafka.
 */
typedef std::vector<uint8_t> Bytes;

/**
- * Nullable bytes array used by Kafka
+ * Nullable bytes array used by Kafka.
 */
typedef absl::optional<Bytes> NullableBytes;

/**
- * Kafka array of elements of type T
+ * Kafka array of elements of type T.
 */
template <typename T> using NullableArray = absl::optional<std::vector<T>>;
diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h
index eec046295d002..e6747ad7f453c 100644
--- a/source/extensions/filters/network/kafka/message.h
+++ b/source/extensions/filters/network/kafka/message.h
@@ -12,14 +12,14 @@ namespace Kafka {

/**
- * Abstract message (that can be either request or response)
+ * Abstract message (that can be either request or response).
 */
class Message {
public:
  virtual ~Message() = default;

  /**
-   * Encode the contents of this message into a given buffer
+   * Encode the contents of this message into a given buffer.
   * @param dst buffer instance to keep serialized message
   */
  virtual size_t encode(Buffer::Instance& dst) const PURE;
diff --git a/source/extensions/filters/network/kafka/parser.h b/source/extensions/filters/network/kafka/parser.h
index 9cef1eb19122f..2ef06dbc2b1fc 100644
--- a/source/extensions/filters/network/kafka/parser.h
+++ b/source/extensions/filters/network/kafka/parser.h
@@ -18,9 +18,7 @@ class ParseResponse;

/**
 * Parser is responsible for consuming data relevant to some part of a message, and then returning
- * the decision how the parsing should continue impl note: better name could be Consumer, but really
- * don't want to use that word considering that it's so prevalent in Kafka world; suggestions
- * welcome
+ * the decision how the parsing should continue.
 */
class Parser : public Logger::Loggable<Logger::Id::kafka> {
public:
@@ -28,7 +26,7 @@
  /**
   * Submit data to be processed by parser, will consume as much data as it is necessary to reach
-   * the conclusion what should be the next parse step
+   * the conclusion what should be the next parse step.
* @param data bytes to be processed, will be updated by parser if any have been consumed * @return parse status - decision what should be done with current parser (keep/replace) */ @@ -39,32 +37,32 @@ typedef std::shared_ptr ParserSharedPtr; /** * Three-state holder representing one of: - * - parser still needs data (`stillWaiting`) + * - parser still needs data (`stillWaiting`), * - parser is finished, and following parser should be used to process the rest of data - * (`nextParser`) - * - parser is finished, and fully-parsed message is attached (`parsedMessage`) + * (`nextParser`), + * - parser is finished, and fully-parsed message is attached (`parsedMessage`). */ class ParseResponse { public: /** - * Constructs a response that states that parser still needs data and should not be replaced + * Constructs a response that states that parser still needs data and should not be replaced. */ static ParseResponse stillWaiting() { return {nullptr, nullptr}; } /** * Constructs a response that states that parser is finished and should be replaced by given - * parser + * parser. */ static ParseResponse nextParser(ParserSharedPtr next_parser) { return {next_parser, nullptr}; }; /** * Constructs a response that states that parser is finished, the message is ready, and parsing - * can start anew for next message + * can start anew for next message. */ static ParseResponse parsedMessage(MessageSharedPtr message) { return {nullptr, message}; }; /** - * If response contains a next parser or the fully parsed message + * If response contains a next parser or the fully parsed message. */ bool hasData() const { return (next_parser_ != nullptr) || (message_ != nullptr); } diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 index 289080b08749b..2395c39ddaab4 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 @@ -1,58 +1,62 @@ {# - Template for structure representing a composite entity in Kafka protocol (e.g. FetchRequest, FetchRequestTopic, FetchRequestPartition) - Rendered templates for each structure in Kafka protocol will be put into 'requests.h' file + Template for structure representing a composite entity in Kafka protocol (e.g. FetchRequest, FetchRequestTopic). + Rendered templates for each structure in Kafka protocol will be put into 'requests.h' file. - Each structure is capable of holding all versions of given entity (what means its fields are actually a superset of union of all versions' fields) - Each version has a dedicated deserializer (named $requestV$versionDeserializer), which calls the matching constructor + Each structure is capable of holding all versions of given entity (what means its fields are actually a superset + of union of all versions' fields). Each version has a dedicated deserializer (named $requestV$versionDeserializer), + which calls the matching constructor. - To serialize, it is necessary to pass the encoding context (that contains the version that's being serialized) - Depending on the version, the fields will be written to the buffer + To serialize, it is necessary to pass the encoding context (that contains the version that's being serialized). + Depending on the version, the fields will be written to the buffer. 
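
  As a hand-written illustration (not output of the real generator; 'ExampleRequest' and its
  fields are invented), a structure whose 'note' field exists only in version 1 would render
  roughly as:

    struct ExampleRequest {
      const int32_t id_;
      const NullableString note_;

      // constructor used in versions: [0]
      ExampleRequest(int32_t id) : id_{id}, note_{absl::nullopt} {};
      // constructor used in versions: [1]
      ExampleRequest(int32_t id, NullableString note) : id_{id}, note_{note} {};

      size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
        const int16_t api_version = encoder.apiVersion();
        size_t written{0};
        if (api_version >= 0 && api_version < 2) {
          written += encoder.encode(id_, dst);
        }
        if (api_version >= 1 && api_version < 2) {
          written += encoder.encode(note_, dst);
        }
        return written;
      }

      // operator== comparing all fields is rendered as well.
    };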
#} struct {{ complex_type.name }} { - {# - Constructors invoked by deserializers - Each constructor has a signature that matches the fields in at least one version - (sometimes there are different Kafka versions that are actually composed of precisely the same fields) - #} - {% for field in complex_type.fields %} - const {{ field.field_declaration() }}_;{% endfor %} - {% for constructor in complex_type.compute_constructors() %} - // constructor used in versions: {{ constructor['versions'] }} - {{ constructor['full_declaration'] }}{% endfor %} - - {# For every field that's used in version, just serialize it #} - {% if complex_type.fields|length > 0 %} - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - const int16_t api_version = encoder.apiVersion(); - size_t written{0};{% for field in complex_type.fields %} - if (api_version >= {{ field.version_usage[0] }} && api_version < {{ field.version_usage[-1] + 1 }}) { - written += encoder.encode({{ field.name }}_, dst); - }{% endfor %} - return written; - } - {% else %} - size_t encode(Buffer::Instance&, EncodingContext&) const { - return 0; - } - {% endif %} - - {% if complex_type.fields|length > 0 %} - bool operator==(const {{ complex_type.name }}& rhs) const { - {% else %} - bool operator==(const {{ complex_type.name }}&) const { - {% endif %} - return true{% for field in complex_type.fields %} - && {{ field.name }}_ == rhs.{{ field.name }}_{% endfor %}; - }; + {# + Constructors invoked by deserializers. + Each constructor has a signature that matches the fields in at least one version (as sometimes there are + different Kafka versions that are actually composed of precisely the same fields). + #} + {% for field in complex_type.fields %} + const {{ field.field_declaration() }}_;{% endfor %} + {% for constructor in complex_type.compute_constructors() %} + // constructor used in versions: {{ constructor['versions'] }} + {{ constructor['full_declaration'] }}{% endfor %} + + {# For every field that's used in version, just serialize it. #} + {% if complex_type.fields|length > 0 %} + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + const int16_t api_version = encoder.apiVersion(); + size_t written{0};{% for field in complex_type.fields %} + if (api_version >= {{ field.version_usage[0] }} && api_version < {{ field.version_usage[-1] + 1 }}) { + written += encoder.encode({{ field.name }}_, dst); + }{% endfor %} + return written; + } + {% else %} + size_t encode(Buffer::Instance&, EncodingContext&) const { + return 0; + } + {% endif %} + + {% if complex_type.fields|length > 0 %} + bool operator==(const {{ complex_type.name }}& rhs) const { + {% else %} + bool operator==(const {{ complex_type.name }}&) const { + {% endif %} + return true{% for field in complex_type.fields %} + && {{ field.name }}_ == rhs.{{ field.name }}_{% endfor %}; + }; }; {# - Each structure version has a deserializer that matches the structure's field list + Each structure version has a deserializer that matches the structure's field list. 
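
  Continuing the hand-written 'ExampleRequest' illustration above, its version 1 deserializer
  would delegate to one deserializer per field, e.g.:

    class ExampleRequestV1Deserializer:
        public CompositeDeserializerWith2Delegates<
          ExampleRequest, Int32Deserializer, NullableStringDeserializer>{};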
#} {% for field_list in complex_type.compute_field_lists() %} class {{ complex_type.name }}V{{ field_list.version }}Deserializer: - public CompositeDeserializerWith{{ field_list.field_count() }}Delegates<{{ complex_type.name }}{% for field in field_list.used_fields() %}, {{ field.deserializer_name_in_version(field_list.version) }}{% endfor %}>{}; + public CompositeDeserializerWith{{ field_list.field_count() }}Delegates< + {{ complex_type.name }} + {% for field in field_list.used_fields() %}, {{ field.deserializer_name_in_version(field_list.version) }} + {% endfor %}>{}; {% endfor %} diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index 8509f4887646d..b505f8706888d 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -5,31 +5,31 @@ def main(): """ Kafka header generator script ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Generates C++ headers from Kafka protocol specification - Can generate both main source code, as well as test code + Generates C++ headers from Kafka protocol specification. + Can generate both main source code, as well as test code. Usage: - kafka_generator.py COMMAND OUTPUT FILES INPUT_FILES + kafka_generator.py COMMAND OUTPUT FILES INPUT_FILES where: - COMMAND : 'generate-source', to generate source files - 'generate-test', to generate test files + COMMAND : 'generate-source', to generate source files, + 'generate-test', to generate test files. OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', - if generate-test: location of 'requests_test.cc', 'request_codec_request_integration_test.cc' - INPUT_FILES: Kafka protocol json files to be processed + if generate-test: location of 'requests_test.cc', 'request_codec_request_integration_test.cc'. + INPUT_FILES: Kafka protocol json files to be processed. - Kafka spec files are provided at https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message and in Kafka clients jar file + Kafka spec files are provided in Kafka clients jar file. When generating source code, it creates: - - requests.h - definition of all the structures/deserializers/parsers related to Kafka requests - - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from requests.h + - requests.h - definition of all the structures/deserializers/parsers related to Kafka requests, + - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from requests.h. When generating test code, it creates: - - requests_test.cc - serialization/deserialization tests for kafka structures - - request_codec_request_integration_test.cc - integration test for all request operations using the codec API + - requests_test.cc - serialization/deserialization tests for kafka structures, + - request_codec_request_integration_test.cc - integration test for all request operations using the codec API. 
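
    Example invocation (mirroring how the BUILD genrules wire it up; the .json file list here is
    illustrative):

        kafka_generator.py generate-source requests.h kafka_request_resolver.cc \
            ProduceRequest.json FetchRequest.json ...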
Templates used are: - - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2 - - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2 - - to create 'requests_test.cc': requests_test_cc.j2 - - to create 'request_codec_request_integration_test.cc' - request_codec_request_integration_test_cc.j2 + - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2, + - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2, + - to create 'requests_test.cc': requests_test_cc.j2, + - to create 'request_codec_request_integration_test.cc' - request_codec_request_integration_test_cc.j2. """ import sys @@ -52,7 +52,7 @@ def main(): requests = [] - # for each request spec file, remove comments, and parse the remains + # For each request specification file, remove comments, and parse the remains. for input_file in input_files: with open(input_file, 'r') as fd: raw_contents = fd.read() @@ -61,10 +61,10 @@ def main(): request = parse_request(request_spec) requests.append(request) - # sort requests by api_key + # Sort requests by api_key. requests.sort(key=lambda x: x.get_extra('api_key')) - # main source code + # Generate main source code. if 'generate-source' == command: complex_type_template = RenderingHelper.get_template('complex_type_template.j2') request_parsers_template = RenderingHelper.get_template('request_parser.j2') @@ -72,13 +72,13 @@ def main(): requests_h_contents = '' for request in requests: - # for each structure that is used by request, render its corresponding structures + # For each child structure that is used by request, render its corresponding C++ code. for dependency in request.declaration_chain: requests_h_contents += complex_type_template.render(complex_type=dependency) - # each top-level structure (e.g. FetchRequest) is going to have corresponding parsers + # Each top-level structure (e.g. FetchRequest) is going to have corresponding parsers. requests_h_contents += request_parsers_template.render(complex_type=request) - # full file with headers, namespace declaration etc. + # Full file with headers, namespace declaration etc. template = RenderingHelper.get_template('requests_h.j2') contents = template.render(contents=requests_h_contents) @@ -91,7 +91,7 @@ def main(): with open(kafka_request_resolver_cc_file, 'w') as fd: fd.write(contents) - # test code + # Generate test code. if 'generate-test' == command: template = RenderingHelper.get_template('requests_test_cc.j2') contents = template.render(request_types=requests) @@ -108,8 +108,9 @@ def main(): def parse_request(spec): """ - Parse a given structure into a request - Request is just a complex type, that has name & versions kept in differently named fields + Parse a given structure into a request. + Request is just a complex type, that has name & version information kept in differently named fields, compared to + sub-structures in a request. """ request_type_name = spec['name'] request_versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) @@ -119,7 +120,7 @@ def parse_request(spec): def parse_complex_type(type_name, field_spec, versions): """ - Parse given complex type, returning a structure that holds its name, field specification and allowed versions + Parse given complex type, returning a structure that holds its name, field specification and allowed versions. 
""" fields = [] for child_field in field_spec['fields']: @@ -130,8 +131,8 @@ def parse_complex_type(type_name, field_spec, versions): def parse_field(field_spec, highest_possible_version): """ - Parse given field, returning a structure holding the name, type, and versions when this field is actually used (nullable or not) - Obviously, field cannot be used in version higher than its type's usage + Parse given field, returning a structure holding the name, type, and versions when this field is actually used + (nullable or not). Obviously, field cannot be used in version higher than its type's usage. """ version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) version_usage_as_nullable = Statics.parse_version_string( @@ -143,10 +144,10 @@ def parse_field(field_spec, highest_possible_version): def parse_type(type_name, field_spec, highest_possible_version): """ - Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one (== struct) + Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one (== struct). """ if (type_name.startswith('[]')): - # in spec files, array types are defined as `[]underlying_type` instead of having its own element with type inside :\ + # In spec files, array types are defined as `[]underlying_type` instead of having its own element with type inside. underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) return Array(underlying_type) else: @@ -162,7 +163,7 @@ class Statics: @staticmethod def parse_version_string(raw_versions, highest_possible_version): """ - Return integer range that corresponds to version string in spec file + Return integer range that corresponds to version string in spec file. """ if raw_versions.endswith('+'): return range(int(raw_versions[:-1]), highest_possible_version + 1) @@ -177,8 +178,8 @@ def parse_version_string(raw_versions, highest_possible_version): class FieldList: """ - List of fields used by given entity (request or child structure) in given request version - (as fields get added/removed across versions) + List of fields used by given entity (request or child structure) in given request version (as fields get added + or removed across versions). """ def __init__(self, version, fields): @@ -187,41 +188,41 @@ def __init__(self, version, fields): def used_fields(self): """ - Return list of fields that are actually used in this version of structure + Return list of fields that are actually used in this version of structure. """ return filter(lambda x: x.used_in_version(self.version), self.fields) def constructor_signature(self): """ - Return constructor signature - Multiple versions of the same structure can have identical signatures (due to version bumps in Kafka) + Return constructor signature. + Multiple versions of the same structure can have identical signatures (due to version bumps in Kafka). """ parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) return ', '.join(parameter_spec) def constructor_init_list(self): """ - Renders member initialization list in constructor - Takes care of potential optional conversions (as field could be T in V1, but optional in V2) + Renders member initialization list in constructor. + Takes care of potential optional conversions (as field could be T in V1, but optional in V2). 
""" init_list = [] for field in self.fields: if field.used_in_version(self.version): if field.is_nullable(): if field.is_nullable_in_version(self.version): - # field is optional, and the parameter is optional in this version + # Field is optional, and the parameter is optional in this version. init_list_item = '%s_{%s}' % (field.name, field.name) init_list.append(init_list_item) else: - # field is optional, and the parameter is T in this version + # Field is optional, and the parameter is T in this version. init_list_item = '%s_{absl::make_optional(%s)}' % (field.name, field.name) init_list.append(init_list_item) else: - # field is T, so parameter cannot be optional + # Field is T, so parameter cannot be optional. init_list_item = '%s_{%s}' % (field.name, field.name) init_list.append(init_list_item) else: - # field is not used in this version, so we need to put in default value + # Field is not used in this version, so we need to put in default value. init_list_item = '%s_{%s}' % (field.name, field.default_value()) init_list.append(init_list_item) pass @@ -236,8 +237,8 @@ def example_value(self): class FieldSpec: """ - Represents a field present in a structure (request, or child structure thereof) - Contains name, type, and versions when it is used (nullable or not) + Represents a field present in a structure (request, or child structure thereof). + Contains name, type, and versions when it is used (nullable or not). """ def __init__(self, name, type, version_usage, version_usage_as_nullable): @@ -253,8 +254,8 @@ def is_nullable(self): def is_nullable_in_version(self, version): """ - Whether thie field is nullable in given version - Fields can be non-nullable in earlier versions + Whether thie field is nullable in given version. + Fields can be non-nullable in earlier versions. See https://github.com/apache/kafka/tree/2.2.0-rc0/clients/src/main/resources/common/message#nullable-fields """ return version in self.version_usage_as_nullable @@ -301,13 +302,13 @@ class TypeSpecification: def deserializer_name_in_version(self, version): """ - Renders the deserializer name of given type, in request with given version + Renders the deserializer name of given type, in request with given version. """ raise NotImplementedError() def default_value(self): """ - Returns a default value for given type + Returns a default value for given type. """ raise NotImplementedError() @@ -320,9 +321,9 @@ def is_printable(self): class Array(TypeSpecification): """ - Represents array complex type + Represents array complex type. To use instance of this type, it is necessary to declare structures required by self.underlying - (e.g. to use Array, we need to have `struct Foo {...}`) + (e.g. to use Array, we need to have `struct Foo {...}`). """ def __init__(self, underlying): @@ -350,7 +351,7 @@ def is_printable(self): class Primitive(TypeSpecification): """ - Represents a Kafka primitive value + Represents a Kafka primitive value. 
""" PRIMITIVE_TYPE_NAMES = ['bool', 'int8', 'int16', 'int32', 'int64', 'string', 'bytes'] @@ -375,7 +376,7 @@ class Primitive(TypeSpecification): 'bytes': 'BytesDeserializer', } - # https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages + # See https://github.com/apache/kafka/tree/trunk/clients/src/main/resources/common/message#deserializing-messages KAFKA_TYPE_TO_DEFAULT_VALUE = { 'string': '""', 'bool': 'false', @@ -386,7 +387,7 @@ class Primitive(TypeSpecification): 'bytes': '{}', } - # to make test code more readable + # Custom values that make test code more readable. KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { 'string': '"string"', 'bool': 'false', @@ -429,8 +430,8 @@ def is_printable(self): class Complex(TypeSpecification): """ - Represents a complex type (multiple types aggregated into one) - This type gets mapped to C++ struct + Represents a complex type (multiple types aggregated into one). + This type gets mapped to a C++ struct. """ def __init__(self, name, fields, versions): @@ -442,8 +443,8 @@ def __init__(self, name, fields, versions): def __compute_declaration_chain(self): """ - Computes all dependendencies, what means all non-primitive types used by this type - They need to be declared before this struct is declared + Computes all dependendencies, what means all non-primitive types used by this type. + They need to be declared before this struct is declared. """ result = [] for field in self.fields: @@ -460,8 +461,8 @@ def get_extra(self, key): def compute_constructors(self): """ - Field lists for different versions may not differ (as Kafka can bump version without any changes) - But constructors need to be unique, so we need to remove duplicates if the signatures match + Field lists for different versions may not differ (as Kafka can bump version without any changes). + But constructors need to be unique, so we need to remove duplicates if the signatures match. """ signature_to_constructor = {} for field_list in self.compute_field_lists(): @@ -483,7 +484,7 @@ def compute_constructors(self): def compute_field_lists(self): """ - Return field lists representing each of structure versions + Return field lists representing each of structure versions. """ field_lists = [] for version in self.versions: @@ -508,7 +509,7 @@ def is_printable(self): class RenderingHelper: """ - Helper for jinja templates + Helper for jinja templates. """ @staticmethod diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 index 6ec0bfd7c6f6d..553a761945b46 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 @@ -1,6 +1,6 @@ {# - Template for 'kafka_request_resolver.cc' - Defines default Kafka request resolver, that uses request parsers in (also generated) 'requests.h' + Template for 'kafka_request_resolver.cc'. + Defines default Kafka request resolver, that uses request parsers in (also generated) 'requests.h'. 
#} #include "extensions/filters/network/kafka/requests.h" #include "extensions/filters/network/kafka/kafka_request_parser.h" @@ -12,8 +12,9 @@ namespace NetworkFilters { namespace Kafka { /** - * Creates a parser that corresponds to provided key and version - * If corresponding parser cannot be found (what means a newer version of Kafka protocol), a sentinel parser is returned + * Creates a parser that corresponds to provided key and version. + * If corresponding parser cannot be found (what means a newer version of Kafka protocol), a sentinel parser is + * returned. * @param api_key Kafka request key * @param api_version Kafka request's version * @param context parse context diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 index 0ac7f4a9c5418..65c0d2b475f1d 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 @@ -1,9 +1,13 @@ {# - Template for 'request_codec_request_integration_test.cc' - - Provides integration tests using Kafka codec - The only thing happening in these tests is creation of messages, passing them to codec, - and verifying that received parsed values are the same as data send + Template for 'request_codec_request_integration_test.cc'. + + Provides integration tests using Kafka codec. + The tests do the following: + - create the message, + - serialize the message into buffer, + - pass the buffer to the codec, + - capture messages received in callback, + - verify that captured messages are identical to the ones sent. #} #include "extensions/filters/network/kafka/request_codec.h" #include "extensions/filters/network/kafka/requests.h" @@ -19,78 +23,78 @@ namespace Kafka { class RequestCodecIntegrationTest : public testing::Test { protected: - template void putInBuffer(T arg); + template void putInBuffer(T arg); - Buffer::OwnedImpl buffer_; + Buffer::OwnedImpl buffer_; }; class CapturingRequestCallback : public RequestCallback { public: - virtual void onMessage(MessageSharedPtr request) override; + virtual void onMessage(MessageSharedPtr request) override; - const std::vector& getCaptured() const; + const std::vector& getCaptured() const; private: - std::vector captured_; + std::vector captured_; }; typedef std::shared_ptr CapturingRequestCallbackSharedPtr; void CapturingRequestCallback::onMessage(MessageSharedPtr message) { - captured_.push_back(message); + captured_.push_back(message); } const std::vector& CapturingRequestCallback::getCaptured() const { - return captured_; + return captured_; } {% for request_type in request_types %} -// integration test for {{ request_type.name }} messages +// Integration test for {{ request_type.name }} messages. 
TEST_F(RequestCodecIntegrationTest, shouldHandle{{ request_type.name }}Messages) { - // given - using Request = ConcreteRequest<{{ request_type.name }}>; - - std::vector sent; - int32_t correlation_id = 0; - - {% for field_list in request_type.compute_field_lists() %} - for (int i = 0; i < 100; ++i ) { - const RequestHeader header = { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation_id++, "client-id" }; - const {{ request_type.name }} data = { {{ field_list.example_value() }} }; - const Request request = {header, data}; - putInBuffer(request); - sent.push_back(request); - } - {% endfor %} - - const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance(); - const RequestParserResolver& request_parser_resolver = RequestParserResolver::getDefaultInstance(); - const CapturingRequestCallbackSharedPtr request_callback = std::make_shared(); - - RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}}; - - // when - testee.onData(buffer_); - - // then - const std::vector& received = request_callback->getCaptured(); - ASSERT_EQ(received.size(), sent.size()); - - for (size_t i = 0; i < received.size(); ++i) { - const std::shared_ptr request = std::dynamic_pointer_cast(received[i]); - ASSERT_NE(request, nullptr); - ASSERT_EQ(*request, sent[i]); - } + // given + using Request = ConcreteRequest<{{ request_type.name }}>; + + std::vector sent; + int32_t correlation_id = 0; + + {% for field_list in request_type.compute_field_lists() %} + for (int i = 0; i < 100; ++i ) { + const RequestHeader header = + { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation_id++, "client-id" }; + const {{ request_type.name }} data = { {{ field_list.example_value() }} }; + const Request request = {header, data}; + putInBuffer(request); + sent.push_back(request); + } + {% endfor %} + + const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance(); + const RequestParserResolver& request_parser_resolver = RequestParserResolver::getDefaultInstance(); + const CapturingRequestCallbackSharedPtr request_callback = std::make_shared(); + + RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}}; + + // when + testee.onData(buffer_); + + // then + const std::vector& received = request_callback->getCaptured(); + ASSERT_EQ(received.size(), sent.size()); + + for (size_t i = 0; i < received.size(); ++i) { + const std::shared_ptr request = std::dynamic_pointer_cast(received[i]); + ASSERT_NE(request, nullptr); + ASSERT_EQ(*request, sent[i]); + } } {% endfor %} -// misc utilities template void RequestCodecIntegrationTest::putInBuffer(const T arg) { - MessageEncoderImpl serializer{buffer_}; - serializer.encode(arg); + MessageEncoderImpl serializer{buffer_}; + serializer.encode(arg); } } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 index 0708d19402ab4..b01f52d2eae36 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 @@ -1,14 +1,16 @@ {# - Template for top-level structure representing a request in Kafka protocol (e.g. ProduceRequest, FetchRequest, ListOffsetsRequest etc.) 
- Rendered templates for each request in Kafka protocol will be put into 'requests.h' file + Template for top-level structure representing a request in Kafka protocol (e.g. ProduceRequest, FetchRequest etc.). + Rendered templates for each request in Kafka protocol will be put into 'requests.h' file. - This template handles binding the top-level structure deserializer (e.g. ProduceRequestV0Deserializer) with RequestParser - These parsers are then used by RequestParserResolver instance depending on received Kafka api key & api version (see 'kafka_request_resolver_cc.j2') + This template handles binding the top-level structure deserializer (e.g. ProduceRequestV0Deserializer) with + RequestParser. These parsers are then used by RequestParserResolver instance depending on received Kafka api key & + api version (see 'kafka_request_resolver_cc.j2'). #} -{% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> { +{% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: + public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> { public: - {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; + {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; }; {% endfor %} \ No newline at end of file diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 index 8dc803f49bfc2..66df9cf56abcf 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 @@ -1,24 +1,25 @@ {# - Main template for 'requests.h' file - Gets filled in (by 'contents') with Kafka request structures, deserializers, and parsers + Main template for 'requests.h' file. + Gets filled in (by 'contents') with Kafka request structures, deserializers, and parsers. - For each request we have the following: - - 1 top-level structure corresponding to the request (e.g. `struct FetchRequest`) - - N deserializers for top-level structure, one for each request version - - N parsers binding each deserializer with parser - - 0+ child structures (e.g. `struct FetchRequestTopic`, `FetchRequestPartition`) that compose into top-level structure - - deserializers for each child structure (M = number of versions where structure is actually used) + For each request we have the following: + - 1 top-level structure corresponding to the request (e.g. `struct FetchRequest`), + - N deserializers for top-level structure, one for each request version, + - N parsers binding each deserializer with parser, + - 0+ child structures (e.g. `struct FetchRequestTopic`, `FetchRequestPartition`) that compose into top-level + structure, + - deserializers for each child structure (M = number of versions where structure is actually used). - So for example, for FetchRequest we have: - - struct FetchRequest - - FetchRequestV0Deserializer, FetchRequestV1Deserializer, FetchRequestV2Deserializer, etc. - - FetchRequestV0Parser, FetchRequestV1Parser, FetchRequestV2Parser, etc. - - struct FetchRequestTopic - - FetchRequestTopicV0Deserializer, FetchRequestTopicV1Deserializer, FetchRequestTopicV2Deserializer, etc. 
- (because topic data is present in every FetchRequest version) - - struct FetchRequestPartition - - FetchRequestPartitionV0Deserializer, FetchRequestPartitionV1Deserializer, FetchRequestPartitionV2Deserializer, etc. - (because partition data is present in every FetchRequestTopic version) + So for example, for FetchRequest we have: + - struct FetchRequest, + - FetchRequestV0Deserializer, FetchRequestV1Deserializer, FetchRequestV2Deserializer, etc., + - FetchRequestV0Parser, FetchRequestV1Parser, FetchRequestV2Parser, etc., + - struct FetchRequestTopic, + - FetchRequestTopicV0Deserializer, FetchRequestTopicV1Deserializer, FetchRequestTopicV2Deserializer, etc. + (because topic data is present in every FetchRequest version), + - struct FetchRequestPartition, + - FetchRequestPartitionV0Deserializer, FetchRequestPartitionV1Deserializer, FetchRequestPartitionV2Deserializer, etc. + (because partition data is present in every FetchRequestTopic version). #} #pragma once #include "extensions/filters/network/kafka/kafka_request.h" diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 index 3a6d82e823245..8d164bdf1fea5 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -1,6 +1,6 @@ {# - Template for request serialization/deserialization tests - For every request, we want to check if it can be serialized and deserialized properly + Template for request serialization/deserialization tests. + For every request, we want to check if it can be serialized and deserialized properly. #} #include "extensions/filters/network/kafka/requests.h" @@ -29,9 +29,9 @@ public: }; /** - * Helper method - * Takes an instance of a request, serializes it, then deserializes it - * This method gets executed for every request * version pair + * Helper method. + * Takes an instance of a request, serializes it, then deserializes it. + * This method gets executed for every request * version pair. */ template std::shared_ptr RequestDecoderTest::serializeAndDeserialize(T request) { MessageEncoderImpl serializer{buffer_}; @@ -49,14 +49,15 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria }; {# - Concrete tests for each request_type and version (field_list) - Each request is naively constructed using some default values (put "string" as std::string, 32 as uint32_t, etc.) + Concrete tests for each request_type and version (field_list). + Each request is naively constructed using some default values (put "string" as std::string, 32 as uint32_t, etc.). 
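+  (The default values used here come from the KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST mapping in
+  kafka_generator.py.)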
#}
{% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %}
TEST_F(RequestDecoderTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) {
  // given
  {{ request_type.name }} data = { {{ field_list.example_value() }} };
-  ConcreteRequest<{{ request_type.name }}> request = { { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, 0, absl::nullopt }, data };
+  ConcreteRequest<{{ request_type.name }}> request = { {
+    {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, 0, absl::nullopt }, data };

  // when
  auto received = serializeAndDeserialize(request);
diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc
index cfc577bc3fb2e..49cf6c1afabb9 100644
--- a/source/extensions/filters/network/kafka/request_codec.cc
+++ b/source/extensions/filters/network/kafka/request_codec.cc
@@ -11,7 +11,6 @@ namespace NetworkFilters {
namespace Kafka {

class RequestStartParserFactory : public InitialParserFactory {
-
  ParserSharedPtr create(const RequestParserResolver& parser_resolver) const override {
    return std::make_shared<RequestStartParser>(parser_resolver);
  }
@@ -21,8 +20,8 @@ const InitialParserFactory& InitialParserFactory::getDefaultInstance() {
  CONSTRUCT_ON_FIRST_USE(RequestStartParserFactory);
}

-// convert buffer to slices and pass them to `doParse`
void RequestDecoder::onData(Buffer::Instance& data) {
+  // Convert buffer to slices and pass them to `doParse`.
  uint64_t num_slices = data.getRawSlices(nullptr, 0);
  STACK_ARRAY(slices, Buffer::RawSlice, num_slices);
  data.getRawSlices(slices.begin(), num_slices);
@@ -33,13 +32,14 @@ void RequestDecoder::onData(Buffer::Instance& data) {

/**
 * Main parse loop:
- * - forward data to current parser
+ * - forward data to current parser,
 * - receive parser response:
- *   -- if still waiting, do nothing
- *   -- if next parser, replace current parser, and keep feeding, if still have data
- *   -- if parser message:
- *   --- notify callbacks
- *   --- replace current parser with new start parser, as we are going to parse another request
+ *   -- if still waiting, do nothing (we wait for more data),
+ *   -- if a parser is given, replace current parser with the new one, and feed it the rest of
+ *      the data,
+ *   -- if a message is given:
+ *   --- notify callbacks,
+ *   --- replace current parser with new start parser, as we are going to start parsing the next
+ *       message.
 */
void RequestDecoder::doParse(const Buffer::RawSlice& slice) {
  const char* bytes = reinterpret_cast<const char*>(slice.mem_);
@@ -47,27 +47,27 @@
  while (!data.empty()) {

-    // feed the data to the parser
+    // Feed the data to the parser.
    ParseResponse result = current_parser_->parse(data);

-    // this loop guarantees that parsers consuming 0 bytes also get processed in this invocation
+    // This loop guarantees that parsers consuming 0 bytes also get processed in this invocation.
    while (result.hasData()) {
      if (!result.next_parser_) {

-        // next parser is not present, so we have finished parsing a message
+        // Next parser is not present, so we have finished parsing a message.
        MessageSharedPtr message = result.message_;
        for (auto& callback : callbacks_) {
          callback->onMessage(result.message_);
        }

-        // as we finished parsing this request, re-initialize the parser
+        // As we finished parsing this request, re-initialize the parser.
current_parser_ = factory_.create(parser_resolver_); } else { - // the next parser that's supposed to consume the rest of payload was given + // The next parser that's supposed to consume the rest of payload was given. current_parser_ = result.next_parser_; } - // keep parsing the data + // Keep parsing the data. result = current_parser_->parse(data); } } @@ -75,13 +75,13 @@ void RequestDecoder::doParse(const Buffer::RawSlice& slice) { void MessageEncoderImpl::encode(const Message& message) { Buffer::OwnedImpl data_buffer; - // TODO(adamkotwasinski) precompute the size instead of using temporary - // also, when we have 'computeSize' method, then we can push encoding request's size into + // TODO(adamkotwasinski) Precompute the size instead of using temporary buffer. + // When we have the 'computeSize' method, then we can push encoding request's size into // Request::encode - int32_t data_len = message.encode(data_buffer); // encode data computing data length + int32_t data_len = message.encode(data_buffer); // Encode data and compute data length. EncodingContext encoder{-1}; - encoder.encode(data_len, output_); // encode data length into result - output_.add(data_buffer); // copy data into result + encoder.encode(data_len, output_); // Encode data length into result. + output_.add(data_buffer); // Copy encoded data into result. } } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index 0150355cc7f6b..7ddbb2f1417fd 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -14,14 +14,14 @@ namespace NetworkFilters { namespace Kafka { /** - * Callback invoked when request is successfully decoded + * Callback invoked when request is successfully decoded. */ class RequestCallback { public: virtual ~RequestCallback() = default; /** - * Callback method invoked when request is successfully decoded + * Callback method invoked when request is successfully decoded. * @param request request that has been decoded */ virtual void onMessage(MessageSharedPtr request) PURE; diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index c1d35fe357ab1..661c84d03da58 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -22,11 +22,11 @@ namespace NetworkFilters { namespace Kafka { /** - * Deserializer is a stateful entity that constructs a result of type T from bytes provided - * It can be feed()-ed data until it is ready, filling the internal store - * When ready(), it is safe to call get() to transform the internally stored bytes into result + * Deserializer is a stateful entity that constructs a result of type T from bytes provided. + * It can be feed()-ed data until it is ready, filling the internal store. + * When ready(), it is safe to call get() to transform the internally stored bytes into result. * Further feed()-ing should have no effect on a buffer (should return 0 and not move - * buffer/remaining) + * provided pointer). * @param T type of deserialized data */ template class Deserializer { @@ -36,29 +36,27 @@ template class Deserializer { /** * Submit data to be processed, will consume as much data as it is necessary. * If any bytes are consumed, then the provided string view is updated by stepping over consumed - * bytes. 
- * Invoking this method when deserializer is ready has no effect (consumes 0 bytes) - * + * bytes. Invoking this method when deserializer is ready has no effect (consumes 0 bytes). * @param data bytes to be processed, will be updated if any have been consumed * @return number of bytes consumed (equal to change in 'data') */ virtual size_t feed(absl::string_view& data) PURE; /** - * Whether deserializer has consumed enough data to return result + * Whether deserializer has consumed enough data to return result. */ virtual bool ready() const PURE; /** - * Returns the entity that is represented by bytes stored in this deserializer - * Should be only called when deserializer is ready + * Returns the entity that is represented by bytes stored in this deserializer. + * Should be only called when deserializer is ready. */ virtual T get() const PURE; }; /** - * Generic integer deserializer (uses array of sizeof(T) bytes) - * After all bytes are filled in, the value is converted from network byte-order and returned + * Generic integer deserializer (uses array of sizeof(T) bytes). + * After all bytes are filled in, the value is converted from network byte-order and returned. */ template class IntDeserializer : public Deserializer { public: @@ -87,7 +85,7 @@ template class IntDeserializer : public Deserializer { }; /** - * Integer deserializer for int8_t + * Integer deserializer for int8_t. */ class Int8Deserializer : public IntDeserializer { public: @@ -99,7 +97,7 @@ class Int8Deserializer : public IntDeserializer { }; /** - * Integer deserializer for int16_t + * Integer deserializer for int16_t. */ class Int16Deserializer : public IntDeserializer { public: @@ -111,7 +109,7 @@ class Int16Deserializer : public IntDeserializer { }; /** - * Integer deserializer for int32_t + * Integer deserializer for int32_t. */ class Int32Deserializer : public IntDeserializer { public: @@ -123,7 +121,7 @@ class Int32Deserializer : public IntDeserializer { }; /** - * Integer deserializer for uint32_t + * Integer deserializer for uint32_t. */ class UInt32Deserializer : public IntDeserializer { public: @@ -135,7 +133,7 @@ class UInt32Deserializer : public IntDeserializer { }; /** - * Integer deserializer for uint64_t + * Integer deserializer for uint64_t. */ class Int64Deserializer : public IntDeserializer { public: @@ -148,13 +146,10 @@ class Int64Deserializer : public IntDeserializer { /** * Deserializer for boolean values - * Uses a single int8 deserializers, and just checks != 0 - * impl note: could have been a subclass of IntDeserializer with a different get function, - * but it makes it harder to understand - * - * Boolean value is stored in a byte. - * Values 0 and 1 are used to represent false and true respectively. + * Uses a single int8 deserializer, and checks whether the result is non-zero. * When reading a boolean value, any non-zero value is considered true. + * Impl note: could have been a subclass of IntDeserializer with a different get function, + * but it makes it harder to understand. */ class BooleanDeserializer : public Deserializer { public: @@ -171,8 +166,8 @@ }; /** - * Deserializer of string value - * First reads length (INT16) and then allocates the buffer of given length + * Deserializer of string value. + * First reads length (INT16) and then allocates the buffer of given length. * * From documentation: * First the length N is given as an INT16.
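To make the STRING wire format just described concrete (INT16 big-endian length, then N bytes), here is a standalone sketch; `decodeKafkaString` is a hypothetical helper written for illustration, not part of the codec:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

// Decode a Kafka STRING (INT16 big-endian length, then N bytes) from raw bytes.
std::string decodeKafkaString(const std::vector<uint8_t>& wire) {
  assert(wire.size() >= 2);
  const int16_t length = static_cast<int16_t>((wire[0] << 8) | wire[1]); // Network byte order.
  assert(length >= 0); // A negative length makes the real deserializer throw EnvoyException.
  assert(wire.size() >= 2 + static_cast<std::size_t>(length));
  return std::string(wire.begin() + 2, wire.begin() + 2 + length);
}

int main() {
  // "abc" on the wire: length 3 as big-endian INT16, then the three characters.
  const std::vector<uint8_t> wire = {0x00, 0x03, 'a', 'b', 'c'};
  assert(decodeKafkaString(wire) == "abc");
  return 0;
}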
@@ -182,12 +177,12 @@ class BooleanDeserializer : public Deserializer { class StringDeserializer : public Deserializer { public: /** - * Can throw EnvoyException if given string length is not valid + * Can throw EnvoyException if given string length is not valid. */ size_t feed(absl::string_view& data) override { const size_t length_consumed = length_buf_.feed(data); if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer + // Break early: we still need to fill in length buffer. return length_consumed; } @@ -230,10 +225,10 @@ class StringDeserializer : public Deserializer { }; /** - * Deserializer of nullable string value - * First reads length (INT16) and then allocates the buffer of given length + * Deserializer of nullable string value. + * First reads length (INT16) and then allocates the buffer of given length. * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning - * null value) + * null value). * * From documentation: * For non-null strings, first the length N is given as an INT16. @@ -243,12 +238,12 @@ class StringDeserializer : public Deserializer { class NullableStringDeserializer : public Deserializer { public: /** - * Can throw EnvoyException if given string length is not valid + * Can throw EnvoyException if given string length is not valid. */ size_t feed(absl::string_view& data) override { const size_t length_consumed = length_buf_.feed(data); if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer + // Break early: we still need to fill in length buffer. return length_consumed; } @@ -306,8 +301,8 @@ class NullableStringDeserializer : public Deserializer { }; /** - * Deserializer of bytes value - * First reads length (INT32) and then allocates the buffer of given length + * Deserializer of bytes value. + * First reads length (INT32) and then allocates the buffer of given length. * * From documentation: * First the length N is given as an INT32. Then N bytes follow. @@ -315,12 +310,12 @@ class NullableStringDeserializer : public Deserializer { class BytesDeserializer : public Deserializer { public: /** - * Can throw EnvoyException if given bytes length is not valid + * Can throw EnvoyException if given bytes length is not valid. */ size_t feed(absl::string_view& data) override { const size_t length_consumed = length_buf_.feed(data); if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer + // Break early: we still need to fill in length buffer. return length_consumed; } @@ -362,10 +357,10 @@ class BytesDeserializer : public Deserializer { }; /** - * Deserializer of nullable bytes value - * First reads length (INT32) and then allocates the buffer of given length + * Deserializer of nullable bytes value. + * First reads length (INT32) and then allocates the buffer of given length. * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning - * null value) + * null value). * * From documentation: * For non-null values, first the length N is given as an INT32. Then N bytes follow. @@ -374,12 +369,12 @@ class BytesDeserializer : public Deserializer { class NullableBytesDeserializer : public Deserializer { public: /** - * Can throw EnvoyException if given bytes length is not valid + * Can throw EnvoyException if given bytes length is not valid. 
*/ size_t feed(absl::string_view& data) override { const size_t length_consumed = length_buf_.feed(data); if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer + // Break early: we still need to fill in length buffer. return length_consumed; } @@ -439,11 +434,11 @@ class NullableBytesDeserializer : public Deserializer { }; /** - * Deserializer for array of objects of the same type + * Deserializer for array of objects of the same type. * * First reads the length of the array, then initializes N underlying deserializers of type - * DeserializerType After the last of N deserializers is ready, the results of each of them are - * gathered and put in a vector + * DeserializerType. After the last of N deserializers is ready, the results of each of them are + * gathered and put in a vector. * @param ResponseType result type returned by deserializer of type DeserializerType * @param DeserializerType underlying deserializer type * @@ -456,13 +451,13 @@ template class ArrayDeserializer : public Deserializer> { public: /** - * Can throw EnvoyException if array length is invalid or if DeserializerType can throw + * Can throw EnvoyException if array length is invalid or if underlying deserializer can throw. */ size_t feed(absl::string_view& data) override { const size_t length_consumed = length_buf_.feed(data); if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer + // Break early: we still need to fill in length buffer. return length_consumed; } @@ -516,11 +511,11 @@ class ArrayDeserializer : public Deserializer> { }; /** - * Deserializer for nullable array of objects of the same type + * Deserializer for nullable array of objects of the same type. * * First reads the length of the array, then initializes N underlying deserializers of type - * DeserializerType After the last of N deserializers is ready, the results of each of them are - * gathered and put in a vector + * DeserializerType. After the last of N deserializers is ready, the results of each of them are + * gathered and put in a vector. * @param ResponseType result type returned by deserializer of type DeserializerType * @param DeserializerType underlying deserializer type * @@ -533,13 +528,13 @@ template class NullableArrayDeserializer : public Deserializer> { public: /** - * Can throw EnvoyException if array length is invalid or if DeserializerType can throw + * Can throw EnvoyException if array length is invalid or if underlying deserializer can throw. */ size_t feed(absl::string_view& data) override { const size_t length_consumed = length_buf_.feed(data); if (!length_buf_.ready()) { - // break early: we still need to fill in length buffer + // Break early: we still need to fill in length buffer. return length_consumed; } @@ -605,34 +600,33 @@ class NullableArrayDeserializer : public Deserializer size_t encode(const T& arg, Buffer::Instance& dst); /** - * Encode given array in a buffer + * Encode given array in a buffer. * @return bytes written */ template size_t encode(const std::vector& arg, Buffer::Instance& dst); /** - * Encode given nullable array in a buffer + * Encode given nullable array in a buffer. * @return bytes written */ template size_t encode(const NullableArray& arg, Buffer::Instance& dst); @@ -645,15 +639,15 @@ class EncodingContext { /** * For non-primitive types, call `encode` on them, to delegate the serialization to the entity - * itself + * itself. 
*/ template inline size_t EncodingContext::encode(const T& arg, Buffer::Instance& dst) { return arg.encode(dst, *this); } /** - * Template overload for int8_t - * Encode a single byte + * Template overload for int8_t. + * Encode a single byte. */ template <> inline size_t EncodingContext::encode(const int8_t& arg, Buffer::Instance& dst) { dst.add(&arg, sizeof(int8_t)); @@ -661,8 +655,8 @@ } /** - * Template overload for int16_t, int32_t, uint32_t, int64_t - * Encode a N-byte integer, converting to network byte-order + * Template overload for int16_t, int32_t, uint32_t, int64_t. + * Encode a N-byte integer, converting to network byte-order. */ #define ENCODE_NUMERIC_TYPE(TYPE, CONVERTER) \ template <> inline size_t EncodingContext::encode(const TYPE& arg, Buffer::Instance& dst) { \ @@ -677,8 +671,8 @@ ENCODE_NUMERIC_TYPE(uint32_t, htobe32); ENCODE_NUMERIC_TYPE(int64_t, htobe64); /** - * Template overload for bool - * Encode boolean as a single byte + * Template overload for bool. + * Encode boolean as a single byte. */ template <> inline size_t EncodingContext::encode(const bool& arg, Buffer::Instance& dst) { int8_t val = arg; @@ -687,8 +681,8 @@ template <> inline size_t EncodingContext::encode(const bool& arg, Buffer::Insta } /** - * Template overload for std::string - * Encode string as INT16 length + N bytes + * Template overload for std::string. + * Encode string as INT16 length + N bytes. */ template <> inline size_t EncodingContext::encode(const std::string& arg, Buffer::Instance& dst) { int16_t string_length = arg.length(); @@ -698,8 +692,8 @@ template <> inline size_t EncodingContext::encode(const std::string& arg, Buffer } /** - * Template overload for NullableString - * Encode nullable string as INT16 length + N bytes (length = -1 for null) + * Template overload for NullableString. + * Encode nullable string as INT16 length + N bytes (length = -1 for null). */ template <> inline size_t EncodingContext::encode(const NullableString& arg, Buffer::Instance& dst) { @@ -712,8 +706,8 @@ inline size_t EncodingContext::encode(const NullableString& arg, Buffer::Instanc } /** - * Template overload for Bytes - * Encode byte array as INT32 length + N bytes + * Template overload for Bytes. + * Encode byte array as INT32 length + N bytes. */ template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Instance& dst) { const int32_t data_length = arg.size(); @@ -723,8 +717,8 @@ template <> inline size_t EncodingContext::encode(const Bytes& arg, Buffer::Inst } /** - * Template overload for NullableBytes - * Encode nullable byte array as INT32 length + N bytes (length = -1 for null) + * Template overload for NullableBytes. + * Encode nullable byte array as INT32 length + N bytes (length = -1 for null value). */ template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buffer::Instance& dst) { if (arg.has_value()) { @@ -736,8 +730,8 @@ template <> inline size_t EncodingContext::encode(const NullableBytes& arg, Buff } /** - * Encode nullable object array to T as INT32 length + N elements - * Each element of type T then serializes itself on its own + * Encode object array to T as INT32 length + N elements. + * Each element of type T then serializes itself on its own.
*/ template size_t EncodingContext::encode(const std::vector& arg, Buffer::Instance& dst) { @@ -746,8 +740,8 @@ size_t EncodingContext::encode(const std::vector& arg, Buffer::Instance& dst) } /** - * Encode nullable object array to T as INT32 length + N elements (length = -1 for null) - * Each element of type T then serializes itself on its own + * Encode nullable object array to T as INT32 length + N elements (length = -1 for null value). + * Each element of type T then serializes itself on its own. */ template size_t EncodingContext::encode(const NullableArray& arg, Buffer::Instance& dst) { @@ -756,8 +750,8 @@ size_t EncodingContext::encode(const NullableArray& arg, Buffer::Instance& ds const size_t header_length = encode(len, dst); size_t written{0}; for (const T& el : *arg) { - // for each of array elements, resolve the correct method again - // elements could be primitives or complex types, so calling encode() on object won't work + // For each of array elements, resolve the correct method again. + // Elements could be primitives or complex types, so calling encode() on object won't work. written += encode(el, dst); } return header_length + written; diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py index 49dcd848114df..21478845a61c9 100755 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py @@ -5,25 +5,25 @@ def main(): """ Serialization composite generator script ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ - Generates main&test source code files for composite deserializers - The files are generated, as they are extremely repetitive (composite deserializer for 0..9 sub-deserializers) + Generates main&test source code files for composite deserializers. + The files are generated, as they are extremely repetitive (composite deserializer for 0..9 sub-deserializers). Usage: - serialization_composite_generator.py COMMAND LOCATION_OF_OUTPUT_FILE + serialization_composite_generator.py COMMAND LOCATION_OF_OUTPUT_FILE where: - COMMAND : 'generate-source', to generate source files - 'generate-test', to generate test files + COMMAND : 'generate-source', to generate source files, + 'generate-test', to generate test files. LOCATION_OF_OUTPUT_FILE : if generate-source: location of 'serialization_composite.h', - if generate-test: location of 'serialization_composite_test.cc' + if generate-test: location of 'serialization_composite_test.cc'. When generating source code, it creates: - - serialization_composite.h - header with declarations of CompositeDeserializerWith???Delegates classes + - serialization_composite.h - header with declarations of CompositeDeserializerWith???Delegates classes. When generating test code, it creates: - - serialization_composite_test.cc - tests for these classes + - serialization_composite_test.cc - tests for these classes. Templates used are: - - to create 'serialization_composite.h': serialization_composite_h.j2 - - to create 'serialization_composite_test.cc': serialization_composite_test_cc.j2 + - to create 'serialization_composite.h': serialization_composite_h.j2, + - to create 'serialization_composite_test.cc': serialization_composite_test_cc.j2. 
""" import sys @@ -40,17 +40,17 @@ def main(): import re import json - # number of fields deserialized by each deserializer + # Number of fields deserialized by each deserializer class. field_counts = range(1, 10) - # main source code + # Generate main source code. if 'generate-source' == command: template = RenderingHelper.get_template('serialization_composite_h.j2') contents = template.render(counts=field_counts) with open(serialization_composite_h_file, 'w') as fd: fd.write(contents) - # test code + # Generate test code. if 'generate-test' == command: template = RenderingHelper.get_template('serialization_composite_test_cc.j2') contents = template.render(counts=field_counts) @@ -60,7 +60,7 @@ def main(): class RenderingHelper: """ - Helper for jinja templates + Helper for jinja templates. """ @staticmethod diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 index 32a40cb9b86d4..44910f99161dd 100644 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 @@ -1,8 +1,8 @@ {# - Creates 'serialization_composite.h' + Creates 'serialization_composite.h'. - Template for composite serializers (the CompositeDeserializerWith_N_Delegates classes) - Covers the corner case of 0 delegates, and then uses templating to create declarations for 1..N variants + Template for composite serializers (the CompositeDeserializerWith_N_Delegates classes). + Covers the corner case of 0 delegates, and then uses templating to create declarations for 1..N variants. #} #pragma once @@ -29,66 +29,65 @@ namespace NetworkFilters { namespace Kafka { /** - * This header contains only composite deserializers - * The basic design is composite deserializer creating delegates DeserializerType1..N - * Result of type ResponseType is constructed by getting results of each of delegates - * These deserializers can throw, if any of the delegate deserializers can + * This header contains only composite deserializers. + * The basic design is composite deserializer creating delegates DeserializerType1..N. + * Result of type ResponseType is constructed by getting results of each of delegates. + * These deserializers can throw, if any of the delegate deserializers can. */ /** - * Composite deserializer that uses 0 deserializer(s) (corner case) - * Always ready, as it consumes no bytes - * Creates a result value using a no-arg constructor + * Composite deserializer that uses 0 deserializer(s) (corner case). + * Does not consume any bytes, and is always ready to return the result. + * Creates a result value using the no-arg ResponseType constructor. 
* @param ResponseType type of deserialized data */ template class CompositeDeserializerWith0Delegates : public Deserializer { public: - CompositeDeserializerWith0Delegates(){}; - size_t feed(absl::string_view&) override { return 0; } - bool ready() const override { return true; } - ResponseType get() const override { return {}; } + CompositeDeserializerWith0Delegates(){}; + size_t feed(absl::string_view&) override { return 0; } + bool ready() const override { return true; } + ResponseType get() const override { return {}; } }; {% for field_count in counts %} /** - * Composite deserializer that uses {{ field_count }} deserializer(s) - * Passes data to each of the underlying deserializers - * (deserializers that are already ready do not consume data, so it's safe). - * The composite deserializer is ready when the last deserializer is ready - * (which means all deserializers before it are ready too) + * Composite deserializer that uses {{ field_count }} deserializer(s). + * Passes data to each of the underlying deserializers (deserializers that are already ready do not consume data, + * so it's safe). The composite deserializer is ready when the last deserializer is ready (which means that all + * deserializers before it are ready too). * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } * - * @param ResponseType type of deserialized data -{% for field in range(1, field_count + 1) %} * @param DeserializerType{{ field }} deserializer {{ field }} (result used as argument {{ field }} of ResponseType's ctor) + * @param ResponseType type of deserialized data{% for field in range(1, field_count + 1) %} + * @param DeserializerType{{ field }} deserializer {{ field }} {% endfor %} */ template < - typename ResponseType{% for field in range(1, field_count + 1) %}, typename DeserializerType{{ field }}{% endfor %} + typename ResponseType{% for field in range(1, field_count + 1) %}, typename DeserializerType{{ field }}{% endfor %} > class CompositeDeserializerWith{{ field_count }}Delegates : public Deserializer { public: - CompositeDeserializerWith{{ field_count }}Delegates(){}; + CompositeDeserializerWith{{ field_count }}Delegates(){}; - size_t feed(absl::string_view& data) override { - size_t consumed = 0; - {% for field in range(1, field_count + 1) %} - consumed += delegate{{ field }}_.feed(data); - {% endfor %} - return consumed; - } + size_t feed(absl::string_view& data) override { + size_t consumed = 0; + {% for field in range(1, field_count + 1) %} + consumed += delegate{{ field }}_.feed(data); + {% endfor %} + return consumed; + } - bool ready() const override { return delegate{{ field_count }}_.ready(); } + bool ready() const override { return delegate{{ field_count }}_.ready(); } - ResponseType get() const override { - return { - {% for field in range(1, field_count + 1) %}delegate{{ field }}_.get(), - {% endfor %}}; - } + ResponseType get() const override { + return { + {% for field in range(1, field_count + 1) %}delegate{{ field }}_.get(), + {% endfor %}}; + } protected: - {% for field in range(1, field_count + 1) %} - DeserializerType{{ field }} delegate{{ field }}_; - {% endfor %} + {% for field in range(1, field_count + 1) %} + DeserializerType{{ field }} delegate{{ field }}_; + {% endfor %} }; {% endfor %} diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 index
4708b508da37c..22533e3a78088 100644 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 @@ -1,8 +1,8 @@ {# - Creates 'serialization_composite_test.cc' + Creates 'serialization_composite_test.cc'. - Template for composite serializer tests (the CompositeDeserializerWith_N_Delegates classes) - Covers the corner case of 0 delegates, and then uses templating to create tests for 1..N cases + Template for composite serializer tests (the CompositeDeserializerWith_N_Delegates classes). + Covers the corner case of 0 delegates, and then uses templating to create tests for 1..N cases. #} #include "extensions/filters/network/kafka/serialization_composite.h" @@ -15,69 +15,69 @@ namespace NetworkFilters { namespace Kafka { /** - * Tests in this class are supposed to check whether serialization operations on composite deserializers are behaving correctly + * Tests in this class are supposed to check whether serialization operations on composite deserializers are correct. */ -// tests for composite deserializer with 0 fields (corner case) +// Tests for composite deserializer with 0 fields (corner case). struct CompositeResultWith0Fields { - size_t encode(Buffer::Instance&, EncodingContext&) const { return 0; } - bool operator==(const CompositeResultWith0Fields&) const { return true; } + size_t encode(Buffer::Instance&, EncodingContext&) const { return 0; } + bool operator==(const CompositeResultWith0Fields&) const { return true; } }; typedef CompositeDeserializerWith0Delegates TestCompositeDeserializer0; -// composite with 0 delegates is special case: it's always ready +// Composite with 0 delegates is a special case: it's always ready. TEST(CompositeDeserializerWith0Delegates, EmptyBufferShouldBeReady) { - // given - const TestCompositeDeserializer0 testee{}; - // when, then - ASSERT_EQ(testee.ready(), true); + // given + const TestCompositeDeserializer0 testee{}; + // when, then + ASSERT_EQ(testee.ready(), true); } TEST(CompositeDeserializerWith0Delegates, ShouldDeserialize) { - const CompositeResultWith0Fields expected{}; - serializeThenDeserializeAndCheckEquality(expected); + const CompositeResultWith0Fields expected{}; + serializeThenDeserializeAndCheckEquality(expected); } -// tests for composite deserializer with N+ fields +// Tests for composite deserializer with N+ fields.
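The delegate pattern these generated classes implement can be shown with a hand-written two-delegate composite; `ByteDeserializer` and `PairDeserializer` below are toy stand-ins that follow the same feed()/ready()/get() contract, not generated code:

#include <cassert>
#include <cstddef>
#include <string_view>

// Toy single-byte deserializer, standing in for a delegate type.
class ByteDeserializer {
public:
  size_t feed(std::string_view& data) {
    if (ready_ || data.empty()) {
      return 0; // A ready deserializer consumes nothing, so feeding every delegate is safe.
    }
    value_ = data.front();
    data.remove_prefix(1);
    ready_ = true;
    return 1;
  }
  bool ready() const { return ready_; }
  char get() const { return value_; }

private:
  bool ready_ = false;
  char value_ = 0;
};

struct Pair {
  char first;
  char second;
};

// Hand-written equivalent of a CompositeDeserializerWith2Delegates<Pair, ...>.
class PairDeserializer {
public:
  size_t feed(std::string_view& data) {
    size_t consumed = 0;
    consumed += delegate1_.feed(data);
    consumed += delegate2_.feed(data);
    return consumed;
  }
  // Ready when the last delegate is ready (all earlier ones are then ready too).
  bool ready() const { return delegate2_.ready(); }
  Pair get() const { return {delegate1_.get(), delegate2_.get()}; }

private:
  ByteDeserializer delegate1_;
  ByteDeserializer delegate2_;
};

int main() {
  std::string_view data = "xy-rest";
  PairDeserializer testee;
  const size_t consumed = testee.feed(data);
  assert(consumed == 2);
  assert(testee.ready());
  assert(testee.get().first == 'x');
  assert(testee.get().second == 'y');
  return 0;
}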
{% for field_count in counts %} struct CompositeResultWith{{ field_count }}Fields { - {% for field in range(1, field_count + 1) %} - const std::string field{{ field }}_; - {% endfor %} - - size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { - size_t written{0}; - {% for field in range(1, field_count + 1) %} - written += encoder.encode(field{{ field }}_, dst); - {% endfor %} - return written; - } - - bool operator==(const CompositeResultWith{{ field_count }}Fields& rhs) const { - return true{% for field in range(1, field_count + 1) %} && field{{ field }}_ == rhs.field{{ field }}_{% endfor %}; - } + {% for field in range(1, field_count + 1) %} + const std::string field{{ field }}_; + {% endfor %} + + size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { + size_t written{0}; + {% for field in range(1, field_count + 1) %} + written += encoder.encode(field{{ field }}_, dst); + {% endfor %} + return written; + } + + bool operator==(const CompositeResultWith{{ field_count }}Fields& rhs) const { + return true{% for field in range(1, field_count + 1) %} && field{{ field }}_ == rhs.field{{ field }}_{% endfor %}; + } }; typedef CompositeDeserializerWith{{ field_count }}Delegates< - CompositeResultWith{{ field_count }}Fields - {% for field in range(1, field_count + 1) %}, StringDeserializer{% endfor %} + CompositeResultWith{{ field_count }}Fields + {% for field in range(1, field_count + 1) %}, StringDeserializer{% endfor %} > TestCompositeDeserializer{{ field_count }}; TEST(CompositeDeserializerWith{{ field_count }}Delegates, EmptyBufferShouldNotBeReady) { - // given - const TestCompositeDeserializer{{ field_count }} testee{}; - // when, then - ASSERT_EQ(testee.ready(), false); + // given + const TestCompositeDeserializer{{ field_count }} testee{}; + // when, then + ASSERT_EQ(testee.ready(), false); } TEST(CompositeDeserializerWith{{ field_count }}Delegates, ShouldDeserialize) { - const CompositeResultWith{{ field_count }}Fields expected{ - {% for field in range(1, field_count + 1) %}"s{{ field }}", {% endfor %} - }; - serializeThenDeserializeAndCheckEquality(expected); + const CompositeResultWith{{ field_count }}Fields expected{ + {% for field in range(1, field_count + 1) %}"s{{ field }}", {% endfor %} + }; + serializeThenDeserializeAndCheckEquality(expected); } {% endfor %} diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index 07a57856a052a..7c9343ad920bd 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -48,8 +48,8 @@ genrule( srcs = [], outs = ["serialization_composite_test.cc"], cmd = """ - ./$(location //source/extensions/filters/network/kafka:serialization_composite_generator) generate-test \ - $(location serialization_composite_test.cc) + ./$(location //source/extensions/filters/network/kafka:serialization_composite_generator) \ + generate-test $(location serialization_composite_test.cc) """, tools = [ "//source/extensions/filters/network/kafka:serialization_composite_generator", diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index 8f1b9365a0522..c367ea3328530 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -28,7 +28,7 @@ class BufferBasedTest : public testing::Test { protected: Buffer::OwnedImpl buffer_; - EncodingContext encoder_{-1}; // 
api_version is not used for request header + EncodingContext encoder_{-1}; // Context's api_version is not used when serializing request header. }; class MockRequestParserResolver : public RequestParserResolver { @@ -108,11 +108,11 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNext TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDuringFeeding) { // given - // throws during feeding + // This deserializer throws during feeding. class ThrowingRequestHeaderDeserializer : public RequestHeaderDeserializer { public: size_t feed(absl::string_view& data) override { - // move some pointers to simulate data consumption + // Move some pointers to simulate data consumption. data = {data.data() + FAILED_DESERIALIZER_STEP, data.size() - FAILED_DESERIALIZER_STEP}; throw EnvoyException("feed"); }; @@ -126,7 +126,7 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur const MockRequestParserResolver parser_resolver; - const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request + const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; RequestHeaderParser testee{parser_resolver, request_context, std::make_unique()}; @@ -150,11 +150,12 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { // given - // throws during feeding + + // This deserializer throws during feeding. class ThrowingDeserializer : public Deserializer { public: size_t feed(absl::string_view&) override { - // move some pointers to simulate data consumption + // Move some pointers to simulate data consumption. throw EnvoyException("feed"); }; @@ -180,7 +181,7 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFee ASSERT_EQ(caught, true); } -// deserializer that consumes FAILED_DESERIALIZER_STEP bytes and returns 0 +// This deserializer consumes FAILED_DESERIALIZER_STEP bytes and returns 0 class SomeBytesDeserializer : public Deserializer { public: size_t feed(absl::string_view& data) override { @@ -195,7 +196,7 @@ class SomeBytesDeserializer : public Deserializer { TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyButLeavingData) { // given - const int32_t request_size = 1024; // there are still 1024 bytes to read to complete the request + const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; RequestParser testee{request_context}; diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc index fdd56b0f75199..58148c09f64cc 100644 --- a/test/extensions/filters/network/kafka/request_codec_integration_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -34,10 +34,11 @@ const std::vector& CapturingRequestCallback::getCaptured() con return captured_; } -// other request types are tested in (generated) 'request_codec_request_integration_test.cc' +// Other request types are tested in (generated) 'request_codec_request_integration_test.cc'. 
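For reference, the header bytes assembled by tests like the one above follow the layout api_key (INT16), api_version (INT16), correlation_id (INT32), client_id (NULLABLE_STRING), all integers in network byte order. A standalone sketch of that layout; `appendBigEndian` is a hypothetical helper written for illustration:

#include <cstdint>
#include <cstdio>
#include <string>
#include <vector>

// Append an integer to `out` in network (big-endian) byte order.
template <typename T> void appendBigEndian(std::vector<uint8_t>& out, T value) {
  for (int shift = (sizeof(T) - 1) * 8; shift >= 0; shift -= 8) {
    out.push_back(static_cast<uint8_t>(value >> shift));
  }
}

int main() {
  // Example header values, mirroring the fields the test feeds into the buffer.
  const int16_t api_key = 3; // Metadata request.
  const int16_t api_version = 0;
  const int32_t correlation_id = 10;
  const std::string client_id = "aaa"; // NULLABLE_STRING: INT16 length, then the bytes.

  std::vector<uint8_t> wire;
  appendBigEndian(wire, api_key);
  appendBigEndian(wire, api_version);
  appendBigEndian(wire, correlation_id);
  appendBigEndian(wire, static_cast<int16_t>(client_id.size()));
  wire.insert(wire.end(), client_id.begin(), client_id.end());

  std::printf("header length: %zu\n", wire.size()); // 2 + 2 + 4 + 2 + 3 = 13 bytes.
  return 0;
}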
TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given - // api keys have values below 100, so the messages generated in this loop should not be recognized + // As real api keys have values below 100, the messages generated in this loop should not be + // recognized by the codec. const int16_t base_api_key = 100; std::vector sent_headers; for (int16_t i = 0; i < 1000; ++i) { @@ -71,7 +72,7 @@ TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { } } -// misc utilities +// Helper function. template void RequestDecoderTest::putInBuffer(T arg) { MessageEncoderImpl serializer{buffer_}; serializer.encode(arg); diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc index 0feb02fec4932..4906bbd56711e 100644 --- a/test/extensions/filters/network/kafka/request_codec_unit_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -72,7 +72,8 @@ TEST_F(RequestDecoderTest, shouldDoNothingIfParserNeverReturnsMessage) { // when testee.onData(buffer_); - // then - request_callback is not interacted with + // then + // There were no interactions with `request_callback`. } TEST_F(RequestDecoderTest, shouldUseNewParserAsResponse) { @@ -93,7 +94,8 @@ TEST_F(RequestDecoderTest, shouldUseNewParserAsResponse) { // when testee.onData(buffer_); - // then - request_callback is not interacted with + // then + // There were no interactions with `request_callback`. } TEST_F(RequestDecoderTest, shouldReturnParsedMessageAndReinitialize) { @@ -118,7 +120,8 @@ TEST_F(RequestDecoderTest, shouldReturnParsedMessageAndReinitialize) { // when testee.onData(buffer_); - // then - request_callback got notified only once + // then + // There was only one message sent to `request_callback`. } TEST_F(RequestDecoderTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { @@ -150,10 +153,12 @@ TEST_F(RequestDecoderTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { // when testee.onData(buffer_); - // then - parser3 was given only empty data (size 0) + // then + // There was only one message sent to `request_callback`. + // After that, `parser3` was created and passed remaining data (that should have been empty). } -// misc utilities +// Helper function. template void RequestDecoderTest::putInBuffer(T arg) { MessageEncoderImpl serializer{buffer_}; serializer.encode(arg); diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index 6eea87453ff2f..ff2e5de20e67f 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -7,10 +7,10 @@ namespace Kafka { /** * Tests in this file are supposed to check whether serialization operations - * on Kafka-primitive types (ints, strings, arrays) are behaving correctly + * on Kafka-primitive types (ints, strings, arrays) are behaving correctly. */ -// freshly created deserializers should not be ready +// Freshly created deserializers should not be ready. #define TEST_EmptyDeserializerShouldNotBeReady(DeserializerClass) \ TEST(DeserializerClass, EmptyBufferShouldNotBeReady) { \ const DeserializerClass testee{}; \ @@ -43,7 +43,7 @@ TEST(NullableArrayDeserializer, EmptyBufferShouldNotBeReady) { ASSERT_EQ(testee.ready(), false); } -// extracted test for numeric buffers +// Extracted test for numeric buffers. 
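The negative-length cases exercised by the tests that follow all reduce to one rule: non-nullable types require length >= 0, nullable types additionally accept -1 as null, and anything smaller is an error. A standalone sketch of that rule for NULLABLE_STRING; `checkNullableStringLength` is a hypothetical helper, not the Envoy deserializer:

#include <cassert>
#include <cstddef>
#include <cstdint>
#include <optional>
#include <stdexcept>
#include <string>

// Interpret a decoded INT16 length the way NULLABLE_STRING does.
std::optional<std::string> checkNullableStringLength(int16_t length) {
  if (length >= 0) {
    return std::string(static_cast<size_t>(length), '?'); // Stand-in for the N bytes that follow.
  }
  if (length == -1) {
    return std::nullopt; // -1 encodes a null value.
  }
  throw std::runtime_error("invalid NULLABLE_STRING length"); // E.g. -2.
}

int main() {
  assert(checkNullableStringLength(3).has_value());
  assert(!checkNullableStringLength(-1).has_value());
  try {
    checkNullableStringLength(-2);
    assert(false); // Should have thrown.
  } catch (const std::runtime_error&) {
    // Expected: lengths below -1 are rejected, as in the ShouldThrowOnInvalidLength tests.
  }
  return 0;
}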
#define TEST_DeserializerShouldDeserialize(BufferClass, DataClass, Value) \ TEST(DataClass, ShouldConsumeCorrectAmountOfData) { \ /* given */ \ @@ -58,7 +58,7 @@ TEST_DeserializerShouldDeserialize(UInt32Deserializer, uint32_t, 42); TEST_DeserializerShouldDeserialize(Int64Deserializer, int64_t, 42); TEST_DeserializerShouldDeserialize(BooleanDeserializer, bool, true); -EncodingContext encoder{-1}; // api_version does not matter for primitive types +EncodingContext encoder{-1}; // Provided api_version does not matter for primitive types. TEST(StringDeserializer, ShouldDeserialize) { const std::string value = "sometext"; @@ -75,7 +75,7 @@ TEST(StringDeserializer, ShouldThrowOnInvalidLength) { StringDeserializer testee; Buffer::OwnedImpl buffer; - int16_t len = -1; // STRING accepts only >= 0 + int16_t len = -1; // STRING accepts length >= 0. encoder.encode(len, buffer); absl::string_view data = {getRawData(buffer), 1024}; @@ -108,7 +108,7 @@ TEST(NullableStringDeserializer, ShouldThrowOnInvalidLength) { NullableStringDeserializer testee; Buffer::OwnedImpl buffer; - int16_t len = -2; // -1 is OK for NULLABLE_STRING + int16_t len = -2; // -1 is OK for NULLABLE_STRING. encoder.encode(len, buffer); absl::string_view data = {getRawData(buffer), 1024}; @@ -133,7 +133,7 @@ TEST(BytesDeserializer, ShouldThrowOnInvalidLength) { BytesDeserializer testee; Buffer::OwnedImpl buffer; - const int32_t bytes_length = -1; // BYTES accepts only >= 0 + const int32_t bytes_length = -1; // BYTES accepts length >= 0. encoder.encode(bytes_length, buffer); absl::string_view data = {getRawData(buffer), 1024}; @@ -163,7 +163,7 @@ TEST(NullableBytesDeserializer, ShouldThrowOnInvalidLength) { NullableBytesDeserializer testee; Buffer::OwnedImpl buffer; - const int32_t bytes_length = -2; // -1 is OK for NULLABLE_BYTES + const int32_t bytes_length = -2; // -1 is OK for NULLABLE_BYTES. encoder.encode(bytes_length, buffer); absl::string_view data = {getRawData(buffer), 1024}; @@ -184,7 +184,7 @@ TEST(ArrayDeserializer, ShouldThrowOnInvalidLength) { ArrayDeserializer testee; Buffer::OwnedImpl buffer; - const int32_t len = -1; // ARRAY accepts only >= 0 + const int32_t len = -1; // ARRAY accepts length >= 0. encoder.encode(len, buffer); absl::string_view data = {getRawData(buffer), 1024}; @@ -205,7 +205,7 @@ TEST(NullableArrayDeserializer, ShouldThrowOnInvalidLength) { NullableArrayDeserializer testee; Buffer::OwnedImpl buffer; - const int32_t len = -2; // -1 is OK for ARRAY + const int32_t len = -2; // -1 is OK for NULLABLE_ARRAY. encoder.encode(len, buffer); absl::string_view data = {getRawData(buffer), 1024}; diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h index 8135d79a48d56..3c8dda15ec1e9 100644 --- a/test/extensions/filters/network/kafka/serialization_utilities.h +++ b/test/extensions/filters/network/kafka/serialization_utilities.h @@ -24,7 +24,7 @@ void assertStringViewIncrement(absl::string_view incremented, absl::string_view ASSERT_EQ(incremented.size(), original.size() - difference); } -// helper function +// Helper function converting buffer to raw bytes. const char* getRawData(const Buffer::OwnedImpl& buffer) { uint64_t num_slices = buffer.getRawSlices(nullptr, 0); STACK_ARRAY(slices, Buffer::RawSlice, num_slices); @@ -32,13 +32,13 @@ const char* getRawData(const Buffer::OwnedImpl& buffer) { return reinterpret_cast((slices[0]).mem_); } -// exactly what is says on the tin: -// 1. serialize expected using Encoder -// 2. 
deserialize byte array using testee deserializer -// 3. verify result = expected -// 4. verify that data pointer moved correct amount -// 5. feed testee more data -// 6. verify that nothing more was consumed +// Exactly what it says on the tin: +// 1. serialize expected using Encoder, +// 2. deserialize byte array using testee deserializer, +// 3. verify that testee is ready, and its result is equal to expected, +// 4. verify that data pointer moved correct amount, +// 5. feed testee more data, +// 6. verify that nothing more was consumed (because the testee has been ready since step 3). template void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { // given @@ -48,7 +48,7 @@ void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { EncodingContext encoder{-1}; const size_t written = encoder.encode(expected, buffer); - // tell parser that there is more data, it should never consume more than written + // Tell parser that there is more data, it should never consume more than written. const absl::string_view orig_data = {getRawData(buffer), 10 * written}; absl::string_view data = orig_data; @@ -69,9 +69,9 @@ void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { assertStringViewIncrement(data, orig_data, consumed); } -// does the same thing as the above test, -// but instead of providing whole data at one, it provides it in N one-byte chunks -// this verifies if deserializer keeps state properly (no overwrites etc.) +// Does the same thing as the above test, but instead of providing whole data at once, it provides +// it in N one-byte chunks. +// This verifies that the deserializer keeps state properly (no overwrites etc.). template void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { // given @@ -87,7 +87,7 @@ void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { absl::string_view data = orig_data; size_t consumed = 0; for (size_t i = 0; i < written; ++i) { - data = {data.data(), 1}; // consume data byte-by-byte + data = {data.data(), 1}; // Consume data byte-by-byte. size_t step = testee.feed(data); consumed += step; ASSERT_EQ(step, 1); @@ -110,7 +110,7 @@ void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { ASSERT_EQ(more_data.size(), 1024); } -// wrapper to run both tests +// Wrapper to run both tests.
template void serializeThenDeserializeAndCheckEquality(AT expected) { serializeThenDeserializeAndCheckEqualityInOneGo(expected); serializeThenDeserializeAndCheckEqualityWithChunks(expected); From 89c1a594a2648d7b32c99bb5358c18f0eb44edcc Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Fri, 22 Mar 2019 14:27:37 -0700 Subject: [PATCH 19/29] Attempt to point to Kafka codec in extension build; fix buffer overflows in unit tests with mocks Signed-off-by: Adam Kotwasinski --- source/extensions/extensions_build_config.bzl | 4 ++- .../kafka/kafka_request_parser_test.cc | 36 ++++++++++--------- .../network/kafka/serialization_utilities.h | 17 +++++---- 3 files changed, 34 insertions(+), 23 deletions(-) diff --git a/source/extensions/extensions_build_config.bzl b/source/extensions/extensions_build_config.bzl index bc6d0667569ea..d76149a5fca36 100644 --- a/source/extensions/extensions_build_config.bzl +++ b/source/extensions/extensions_build_config.bzl @@ -69,7 +69,9 @@ EXTENSIONS = { "envoy.filters.network.echo": "//source/extensions/filters/network/echo:config", "envoy.filters.network.ext_authz": "//source/extensions/filters/network/ext_authz:config", "envoy.filters.network.http_connection_manager": "//source/extensions/filters/network/http_connection_manager:config", - "envoy.filters.network.kafka": "//source/extensions/filters/network/kafka:config", + # NOTE: Kafka filter does not have a proper filter implemented right now. We are referencing the + # codec implementation that is going to be used by the filter. + "envoy.filters.network.kafka": "//source/extensions/filters/network/kafka:kafka_request_codec_lib", "envoy.filters.network.mongo_proxy": "//source/extensions/filters/network/mongo_proxy:config", "envoy.filters.network.mysql_proxy": "//source/extensions/filters/network/mysql_proxy:config", "envoy.filters.network.ratelimit": "//source/extensions/filters/network/ratelimit:config", diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index c367ea3328530..0216968272a4c 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -17,8 +17,6 @@ const int32_t FAILED_DESERIALIZER_STEP = 13; class BufferBasedTest : public testing::Test { public: - Buffer::OwnedImpl& buffer() { return buffer_; } - const char* getBytes() { uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); STACK_ARRAY(slices, Buffer::RawSlice, num_slices); buffer_.getRawSlices(slices.begin(), num_slices); return reinterpret_cast((slices[0]).mem_); } + template size_t putIntoBuffer(const T arg) { + EncodingContext encoder_{-1}; // Context's api_version is not used when serializing primitives. + return encoder_.encode(arg, buffer_); + } + + absl::string_view putGarbageIntoBuffer(size_t size = 10000) { + putIntoBuffer(Bytes(size)); + return {getBytes(), size}; + } + protected: Buffer::OwnedImpl buffer_; - EncodingContext encoder_{-1}; // Context's api_version is not used when serializing request header.
}; class MockRequestParserResolver : public RequestParserResolver { @@ -43,7 +50,7 @@ TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { RequestStartParser testee{resolver}; int32_t request_len = 1234; - encoder_.encode(request_len, buffer()); + putIntoBuffer(request_len); const absl::string_view orig_data = {getBytes(), 1024}; absl::string_view data = orig_data; @@ -82,12 +89,12 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNext const int32_t correlation_id{10}; const NullableString client_id{"aaa"}; size_t header_len = 0; - header_len += encoder_.encode(api_key, buffer()); - header_len += encoder_.encode(api_version, buffer()); - header_len += encoder_.encode(correlation_id, buffer()); - header_len += encoder_.encode(client_id, buffer()); + header_len += putIntoBuffer(api_key); + header_len += putIntoBuffer(api_version); + header_len += putIntoBuffer(correlation_id); + header_len += putIntoBuffer(client_id); - const absl::string_view orig_data = {getBytes(), 100000}; + const absl::string_view orig_data = putGarbageIntoBuffer(); absl::string_view data = orig_data; // when @@ -131,7 +138,7 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur RequestHeaderParser testee{parser_resolver, request_context, std::make_unique()}; - const absl::string_view orig_data = {getBytes(), 100000}; + const absl::string_view orig_data = putGarbageIntoBuffer(); absl::string_view data = orig_data; // when @@ -167,7 +174,7 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFee RequestContextSharedPtr request_context{new RequestContext{1024, {}}}; RequestParser testee{request_context}; - absl::string_view data = {getBytes(), 100000}; + absl::string_view data = putGarbageIntoBuffer(); // when bool caught = false; @@ -201,7 +208,7 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyBut RequestParser testee{request_context}; - const absl::string_view orig_data = {getBytes(), 100000}; + const absl::string_view orig_data = putGarbageIntoBuffer(); absl::string_view data = orig_data; // when @@ -225,10 +232,7 @@ TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { context->remaining_request_size_ = request_len; SentinelParser testee{context}; - const Bytes garbage(request_len * 2); - encoder_.encode(garbage, buffer()); - - const absl::string_view orig_data = {getBytes(), request_len * 2}; + const absl::string_view orig_data = putGarbageIntoBuffer(request_len * 2); absl::string_view data = orig_data; // when diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h index 3c8dda15ec1e9..ee4f1a19cea58 100644 --- a/test/extensions/filters/network/kafka/serialization_utilities.h +++ b/test/extensions/filters/network/kafka/serialization_utilities.h @@ -47,9 +47,11 @@ void serializeThenDeserializeAndCheckEqualityInOneGo(AT expected) { Buffer::OwnedImpl buffer; EncodingContext encoder{-1}; const size_t written = encoder.encode(expected, buffer); + // Insert garbage after serialized payload. + const size_t garbage_size = encoder.encode(Bytes(10000), buffer); // Tell parser that there is more data, it should never consume more than written. 
- const absl::string_view orig_data = {getRawData(buffer), 10 * written}; + const absl::string_view orig_data = {getRawData(buffer), written + garbage_size}; absl::string_view data = orig_data; // when @@ -80,8 +82,10 @@ void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { Buffer::OwnedImpl buffer; EncodingContext encoder{-1}; const size_t written = encoder.encode(expected, buffer); + // Insert garbage after serialized payload. + const size_t garbage_size = encoder.encode(Bytes(10000), buffer); - const absl::string_view orig_data = {getRawData(buffer), written}; + const absl::string_view orig_data = {getRawData(buffer), written + garbage_size}; // when absl::string_view data = orig_data; @@ -98,16 +102,17 @@ void serializeThenDeserializeAndCheckEqualityWithChunks(AT expected) { ASSERT_EQ(consumed, written); ASSERT_EQ(testee.ready(), true); ASSERT_EQ(testee.get(), expected); - assertStringViewIncrement(data, orig_data, consumed); + + ASSERT_EQ(data.data(), orig_data.data() + consumed); // when - 2 - absl::string_view more_data = {data.data(), 1024}; + absl::string_view more_data = {data.data(), garbage_size}; const size_t consumed2 = testee.feed(more_data); // then - 2 (nothing changes) ASSERT_EQ(consumed2, 0); - ASSERT_EQ(more_data.data(), orig_data.data() + consumed); - ASSERT_EQ(more_data.size(), 1024); + ASSERT_EQ(more_data.data(), data.data()); + ASSERT_EQ(more_data.size(), garbage_size); } // Wrapper to run both tests. From acdbcb3a13e7517947da00cebf32208951b1bbab Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Fri, 22 Mar 2019 16:30:17 -0700 Subject: [PATCH 20/29] Explicitly provide type of test values in generated Kafka tests Signed-off-by: Adam Kotwasinski --- .../kafka/protocol_code_generator/kafka_generator.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index b505f8706888d..4df8a58f69021 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -391,10 +391,10 @@ class Primitive(TypeSpecification): KAFKA_TYPE_TO_EXAMPLE_VALUE_FOR_TEST = { 'string': '"string"', 'bool': 'false', - 'int8': '8', - 'int16': '16', - 'int32': '32', - 'int64': '64ll', + 'int8': 'static_cast(8)', + 'int16': 'static_cast(16)', + 'int32': 'static_cast(32)', + 'int64': 'static_cast(64)', 'bytes': 'Bytes({0, 1, 2, 3})', } From f5f5f3215bc5d515cea31ab4b6e430417b5eb461 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Mon, 25 Mar 2019 11:05:21 -0700 Subject: [PATCH 21/29] Reorganize test code Signed-off-by: Adam Kotwasinski --- ...quest_codec_request_integration_test_cc.j2 | 21 +---------- test/extensions/filters/network/kafka/BUILD | 4 +++ .../kafka/request_codec_integration_test.cc | 19 +--------- .../network/kafka/serialization_utilities.cc | 31 ++++++++++++++++ .../network/kafka/serialization_utilities.h | 35 +++++++++++++------ 5 files changed, 61 insertions(+), 49 deletions(-) create mode 100644 test/extensions/filters/network/kafka/serialization_utilities.cc diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 index 65c0d2b475f1d..cd2ce77f4ea87 100644 --- 
a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 @@ -12,6 +12,7 @@ #include "extensions/filters/network/kafka/request_codec.h" #include "extensions/filters/network/kafka/requests.h" +#include "test/extensions/filters/network/kafka/serialization_utilities.h" #include "test/mocks/server/mocks.h" #include "gtest/gtest.h" @@ -28,26 +29,6 @@ protected: Buffer::OwnedImpl buffer_; }; -class CapturingRequestCallback : public RequestCallback { -public: - virtual void onMessage(MessageSharedPtr request) override; - - const std::vector& getCaptured() const; - -private: - std::vector captured_; -}; - -typedef std::shared_ptr CapturingRequestCallbackSharedPtr; - -void CapturingRequestCallback::onMessage(MessageSharedPtr message) { - captured_.push_back(message); -} - -const std::vector& CapturingRequestCallback::getCaptured() const { - return captured_; -} - {% for request_type in request_types %} // Integration test for {{ request_type.name }} messages. diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index 7c9343ad920bd..e065402656545 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -14,9 +14,11 @@ envoy_package() envoy_cc_test_library( name = "serialization_utilities_lib", + srcs = ["serialization_utilities.cc"], hdrs = ["serialization_utilities.h"], deps = [ "//source/common/buffer:buffer_lib", + "//source/extensions/filters/network/kafka:kafka_request_codec_lib", "//source/extensions/filters/network/kafka:serialization_lib", ], ) @@ -82,6 +84,7 @@ envoy_extension_cc_test( srcs = ["request_codec_integration_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ + ":serialization_utilities_lib", "//source/extensions/filters/network/kafka:kafka_request_codec_lib", "//test/mocks/server:server_mocks", ], @@ -92,6 +95,7 @@ envoy_extension_cc_test( srcs = ["request_codec_request_integration_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ + ":serialization_utilities_lib", "//source/extensions/filters/network/kafka:kafka_request_codec_lib", "//test/mocks/server:server_mocks", ], diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc index 58148c09f64cc..d692801fe1787 100644 --- a/test/extensions/filters/network/kafka/request_codec_integration_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -1,5 +1,6 @@ #include "extensions/filters/network/kafka/request_codec.h" +#include "test/extensions/filters/network/kafka/serialization_utilities.h" #include "test/mocks/server/mocks.h" #include "gtest/gtest.h" @@ -16,24 +17,6 @@ class RequestDecoderTest : public testing::Test { Buffer::OwnedImpl buffer_; }; -class CapturingRequestCallback : public RequestCallback { -public: - virtual void onMessage(MessageSharedPtr request) override; - - const std::vector& getCaptured() const; - -private: - std::vector captured_; -}; - -typedef std::shared_ptr CapturingRequestCallbackSharedPtr; - -void CapturingRequestCallback::onMessage(MessageSharedPtr message) { captured_.push_back(message); } - -const std::vector& CapturingRequestCallback::getCaptured() const { - return captured_; -} - // Other request types are tested in (generated) 
'request_codec_request_integration_test.cc'. TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { // given diff --git a/test/extensions/filters/network/kafka/serialization_utilities.cc b/test/extensions/filters/network/kafka/serialization_utilities.cc new file mode 100644 index 0000000000000..63eca308c269c --- /dev/null +++ b/test/extensions/filters/network/kafka/serialization_utilities.cc @@ -0,0 +1,31 @@ +#include "test/extensions/filters/network/kafka/serialization_utilities.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +void assertStringViewIncrement(absl::string_view incremented, absl::string_view original, + size_t difference) { + + ASSERT_EQ(incremented.data(), original.data() + difference); + ASSERT_EQ(incremented.size(), original.size() - difference); +} + +const char* getRawData(const Buffer::OwnedImpl& buffer) { + uint64_t num_slices = buffer.getRawSlices(nullptr, 0); + STACK_ARRAY(slices, Buffer::RawSlice, num_slices); + buffer.getRawSlices(slices.begin(), num_slices); + return reinterpret_cast((slices[0]).mem_); +} + +void CapturingRequestCallback::onMessage(MessageSharedPtr message) { captured_.push_back(message); } + +const std::vector& CapturingRequestCallback::getCaptured() const { + return captured_; +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h index ee4f1a19cea58..37a1540c040cd 100644 --- a/test/extensions/filters/network/kafka/serialization_utilities.h +++ b/test/extensions/filters/network/kafka/serialization_utilities.h @@ -3,6 +3,7 @@ #include "common/buffer/buffer_impl.h" #include "common/common/stack_array.h" +#include "extensions/filters/network/kafka/request_codec.h" #include "extensions/filters/network/kafka/serialization.h" #include "absl/strings/string_view.h" @@ -18,19 +19,10 @@ namespace Kafka { * by 'difference' bytes. */ void assertStringViewIncrement(absl::string_view incremented, absl::string_view original, - size_t difference) { - - ASSERT_EQ(incremented.data(), original.data() + difference); - ASSERT_EQ(incremented.size(), original.size() - difference); -} + size_t difference); // Helper function converting buffer to raw bytes. -const char* getRawData(const Buffer::OwnedImpl& buffer) { - uint64_t num_slices = buffer.getRawSlices(nullptr, 0); - STACK_ARRAY(slices, Buffer::RawSlice, num_slices); - buffer.getRawSlices(slices.begin(), num_slices); - return reinterpret_cast((slices[0]).mem_); -} +const char* getRawData(const Buffer::OwnedImpl& buffer); // Exactly what is says on the tin: // 1. serialize expected using Encoder, @@ -121,6 +113,27 @@ template void serializeThenDeserializeAndCheckEqualit serializeThenDeserializeAndCheckEqualityWithChunks(expected); } +/** + * Request callback that captures the messages. + */ +class CapturingRequestCallback : public RequestCallback { +public: + /** + * Stores the message. + */ + virtual void onMessage(MessageSharedPtr request) override; + + /** + * Returns the stored messages. 
+ */ + const std::vector& getCaptured() const; + +private: + std::vector captured_; +}; + +typedef std::shared_ptr CapturingRequestCallbackSharedPtr; + } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions From 20c7bf119132a95ff894cdf21d71f28f2314a7bf Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 26 Mar 2019 10:57:29 -0700 Subject: [PATCH 22/29] Create separate test class for each of Kafka tests; add missing formatting Signed-off-by: Adam Kotwasinski --- .../complex_type_template.j2 | 22 +++++----- .../kafka_generator.py | 44 ++++++++++++------- .../kafka_request_resolver_cc.j2 | 10 +++-- ...quest_codec_request_integration_test_cc.j2 | 16 ++++--- .../protocol_code_generator/request_parser.j2 | 12 ++--- .../protocol_code_generator/requests_h.j2 | 10 ++--- .../requests_test_cc.j2 | 12 ++--- .../serialization_composite_generator.py | 6 ++- .../serialization_composite_h.j2 | 11 +++-- .../serialization_composite_test_cc.j2 | 7 ++- .../kafka/kafka_request_parser_test.cc | 14 +++--- .../kafka/request_codec_integration_test.cc | 6 +-- .../network/kafka/request_codec_unit_test.cc | 12 ++--- 13 files changed, 105 insertions(+), 77 deletions(-) diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 index 2395c39ddaab4..46251d5fb5af0 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/complex_type_template.j2 @@ -1,20 +1,20 @@ {# - Template for structure representing a composite entity in Kafka protocol (e.g. FetchRequest, FetchRequestTopic). + Template for structure representing a composite entity in Kafka protocol (e.g. FetchRequest). Rendered templates for each structure in Kafka protocol will be put into 'requests.h' file. - Each structure is capable of holding all versions of given entity (what means its fields are actually a superset - of union of all versions' fields). Each version has a dedicated deserializer (named $requestV$versionDeserializer), - which calls the matching constructor. + Each structure is capable of holding all versions of given entity (what means its fields are + actually a superset of union of all versions' fields). Each version has a dedicated deserializer + (named $requestV$versionDeserializer), which calls the matching constructor. - To serialize, it is necessary to pass the encoding context (that contains the version that's being serialized). - Depending on the version, the fields will be written to the buffer. + To serialize, it is necessary to pass the encoding context (that contains the version that's + being serialized). Depending on the version, the fields will be written to the buffer. #} struct {{ complex_type.name }} { {# Constructors invoked by deserializers. - Each constructor has a signature that matches the fields in at least one version (as sometimes there are - different Kafka versions that are actually composed of precisely the same fields). + Each constructor has a signature that matches the fields in at least one version (as sometimes + there are different Kafka versions that are actually composed of precisely the same fields). 
#} {% for field in complex_type.fields %} const {{ field.field_declaration() }}_;{% endfor %} @@ -27,7 +27,8 @@ struct {{ complex_type.name }} { size_t encode(Buffer::Instance& dst, EncodingContext& encoder) const { const int16_t api_version = encoder.apiVersion(); size_t written{0};{% for field in complex_type.fields %} - if (api_version >= {{ field.version_usage[0] }} && api_version < {{ field.version_usage[-1] + 1 }}) { + if (api_version >= {{ field.version_usage[0] }} + && api_version < {{ field.version_usage[-1] + 1 }}) { written += encoder.encode({{ field.name }}_, dst); }{% endfor %} return written; @@ -56,7 +57,8 @@ struct {{ complex_type.name }} { class {{ complex_type.name }}V{{ field_list.version }}Deserializer: public CompositeDeserializerWith{{ field_list.field_count() }}Delegates< {{ complex_type.name }} - {% for field in field_list.used_fields() %}, {{ field.deserializer_name_in_version(field_list.version) }} + {% for field in field_list.used_fields() %}, + {{ field.deserializer_name_in_version(field_list.version) }} {% endfor %}>{}; {% endfor %} diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index 4df8a58f69021..1d3b4886fa35a 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -14,22 +14,26 @@ def main(): COMMAND : 'generate-source', to generate source files, 'generate-test', to generate test files. OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', - if generate-test: location of 'requests_test.cc', 'request_codec_request_integration_test.cc'. + if generate-test: location of 'requests_test.cc', + 'request_codec_request_integration_test.cc'. INPUT_FILES: Kafka protocol json files to be processed. Kafka spec files are provided in Kafka clients jar file. When generating source code, it creates: - requests.h - definition of all the structures/deserializers/parsers related to Kafka requests, - - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from requests.h. + - kafka_request_resolver.cc - resolver that binds api_key & api_version to parsers from + requests.h. When generating test code, it creates: - requests_test.cc - serialization/deserialization tests for kafka structures, - - request_codec_request_integration_test.cc - integration test for all request operations using the codec API. + - request_codec_request_integration_test.cc - integration test for all request operations using + the codec API. Templates used are: - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2, - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2, - to create 'requests_test.cc': requests_test_cc.j2, - - to create 'request_codec_request_integration_test.cc' - request_codec_request_integration_test_cc.j2. + - to create 'request_codec_request_integration_test.cc' - + request_codec_request_integration_test_cc.j2. """ import sys @@ -109,8 +113,8 @@ def main(): def parse_request(spec): """ Parse a given structure into a request. - Request is just a complex type, that has name & version information kept in differently named fields, compared to - sub-structures in a request. + Request is just a complex type, that has name & version information kept in differently named + fields, compared to sub-structures in a request. 
""" request_type_name = spec['name'] request_versions = Statics.parse_version_string(spec['validVersions'], 2 << 16 - 1) @@ -120,7 +124,8 @@ def parse_request(spec): def parse_complex_type(type_name, field_spec, versions): """ - Parse given complex type, returning a structure that holds its name, field specification and allowed versions. + Parse given complex type, returning a structure that holds its name, field specification and + allowed versions. """ fields = [] for child_field in field_spec['fields']: @@ -131,8 +136,9 @@ def parse_complex_type(type_name, field_spec, versions): def parse_field(field_spec, highest_possible_version): """ - Parse given field, returning a structure holding the name, type, and versions when this field is actually used - (nullable or not). Obviously, field cannot be used in version higher than its type's usage. + Parse given field, returning a structure holding the name, type, and versions when this field is + actually used (nullable or not). Obviously, field cannot be used in version higher than its + type's usage. """ version_usage = Statics.parse_version_string(field_spec['versions'], highest_possible_version) version_usage_as_nullable = Statics.parse_version_string( @@ -144,10 +150,11 @@ def parse_field(field_spec, highest_possible_version): def parse_type(type_name, field_spec, highest_possible_version): """ - Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one (== struct). + Parse a given type element - returns an array type, primitive (e.g. uint32_t) or complex one. """ if (type_name.startswith('[]')): - # In spec files, array types are defined as `[]underlying_type` instead of having its own element with type inside. + # In spec files, array types are defined as `[]underlying_type` instead of having its own + # element with type inside. underlying_type = parse_type(type_name[2:], field_spec, highest_possible_version) return Array(underlying_type) else: @@ -178,8 +185,8 @@ def parse_version_string(raw_versions, highest_possible_version): class FieldList: """ - List of fields used by given entity (request or child structure) in given request version (as fields get added - or removed across versions). + List of fields used by given entity (request or child structure) in given request version + (as fields get added or removed across versions). """ def __init__(self, version, fields): @@ -195,7 +202,8 @@ def used_fields(self): def constructor_signature(self): """ Return constructor signature. - Multiple versions of the same structure can have identical signatures (due to version bumps in Kafka). + Multiple versions of the same structure can have identical signatures (due to version bumps in + Kafka). """ parameter_spec = map(lambda x: x.parameter_declaration(self.version), self.used_fields()) return ', '.join(parameter_spec) @@ -203,7 +211,8 @@ def constructor_signature(self): def constructor_init_list(self): """ Renders member initialization list in constructor. - Takes care of potential optional conversions (as field could be T in V1, but optional in V2). + Takes care of potential optional conversions (as field could be T in V1, but optional + in V2). """ init_list = [] for field in self.fields: @@ -461,8 +470,9 @@ def get_extra(self, key): def compute_constructors(self): """ - Field lists for different versions may not differ (as Kafka can bump version without any changes). - But constructors need to be unique, so we need to remove duplicates if the signatures match. 
+ Field lists for different versions may not differ (as Kafka can bump version without any + changes). But constructors need to be unique, so we need to remove duplicates if the signatures + match. """ signature_to_constructor = {} for field_list in self.compute_field_lists(): diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 index 553a761945b46..9a8de2bf087d4 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 @@ -1,6 +1,7 @@ {# Template for 'kafka_request_resolver.cc'. - Defines default Kafka request resolver, that uses request parsers in (also generated) 'requests.h'. + Defines default Kafka request resolver, that uses request parsers in (also generated) + 'requests.h'. #} #include "extensions/filters/network/kafka/requests.h" #include "extensions/filters/network/kafka/kafka_request_parser.h" @@ -13,8 +14,8 @@ namespace Kafka { /** * Creates a parser that corresponds to provided key and version. - * If corresponding parser cannot be found (what means a newer version of Kafka protocol), a sentinel parser is - * returned. + * If corresponding parser cannot be found (what means a newer version of Kafka protocol), + * a sentinel parser is returned. * @param api_key Kafka request key * @param api_version Kafka request's version * @param context parse context @@ -23,7 +24,8 @@ ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api RequestContextSharedPtr context) const { {% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %} - if ({{ request_type.get_extra('api_key') }} == api_key && {{ field_list.version }} == api_version) { + if ({{ request_type.get_extra('api_key') }} == api_key + && {{ field_list.version }} == api_version) { return std::make_shared<{{ request_type.name }}V{{ field_list.version }}Parser>(context); }{% endfor %}{% endfor %} return std::make_shared(context); diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 index cd2ce77f4ea87..fb542372a1295 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 @@ -22,7 +22,7 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -class RequestCodecIntegrationTest : public testing::Test { +class RequestCodecRequestTest : public testing::Test { protected: template void putInBuffer(T arg); @@ -33,17 +33,17 @@ protected: // Integration test for {{ request_type.name }} messages. 
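For reference, the dispatch that the kafka_request_resolver_cc.j2 template above renders to has roughly the following shape (a sketch only; ProduceRequestV0Parser and the api key value 0 stand in for the generated identifiers):

ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version,
                                                    RequestContextSharedPtr context) const {
  // One branch is generated per supported (api_key, api_version) pair.
  if (0 == api_key && 0 == api_version) {
    return std::make_shared<ProduceRequestV0Parser>(context);
  }
  // ... remaining generated branches ...
  // Anything unrecognized (e.g. a newer Kafka protocol version) falls through
  // to a sentinel parser that consumes and discards the rest of the request.
  return std::make_shared<SentinelParser>(context);
}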
-TEST_F(RequestCodecIntegrationTest, shouldHandle{{ request_type.name }}Messages) { +TEST_F(RequestCodecRequestTest, shouldHandle{{ request_type.name }}Messages) { // given using Request = ConcreteRequest<{{ request_type.name }}>; std::vector sent; - int32_t correlation_id = 0; + int32_t correlation = 0; {% for field_list in request_type.compute_field_lists() %} for (int i = 0; i < 100; ++i ) { const RequestHeader header = - { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation_id++, "client-id" }; + { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation++, "id" }; const {{ request_type.name }} data = { {{ field_list.example_value() }} }; const Request request = {header, data}; putInBuffer(request); @@ -52,8 +52,10 @@ TEST_F(RequestCodecIntegrationTest, shouldHandle{{ request_type.name }}Messages) {% endfor %} const InitialParserFactory& initial_parser_factory = InitialParserFactory::getDefaultInstance(); - const RequestParserResolver& request_parser_resolver = RequestParserResolver::getDefaultInstance(); - const CapturingRequestCallbackSharedPtr request_callback = std::make_shared(); + const RequestParserResolver& request_parser_resolver = + RequestParserResolver::getDefaultInstance(); + const CapturingRequestCallbackSharedPtr request_callback = + std::make_shared(); RequestDecoder testee{initial_parser_factory, request_parser_resolver, {request_callback}}; @@ -73,7 +75,7 @@ TEST_F(RequestCodecIntegrationTest, shouldHandle{{ request_type.name }}Messages) {% endfor %} template -void RequestCodecIntegrationTest::putInBuffer(const T arg) { +void RequestCodecRequestTest::putInBuffer(const T arg) { MessageEncoderImpl serializer{buffer_}; serializer.encode(arg); } diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 index b01f52d2eae36..33f67b3f85f83 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 @@ -1,14 +1,16 @@ {# - Template for top-level structure representing a request in Kafka protocol (e.g. ProduceRequest, FetchRequest etc.). + Template for top-level structure representing a request in Kafka protocol (e.g. ProduceRequest). Rendered templates for each request in Kafka protocol will be put into 'requests.h' file. - This template handles binding the top-level structure deserializer (e.g. ProduceRequestV0Deserializer) with - RequestParser. These parsers are then used by RequestParserResolver instance depending on received Kafka api key & - api version (see 'kafka_request_resolver_cc.j2'). + This template handles binding the top-level structure deserializer + (e.g. ProduceRequestV0Deserializer) with RequestParser. These parsers are then used by + RequestParserResolver instance depending on received Kafka api key & api version + (see 'kafka_request_resolver_cc.j2'). 
#} {% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: - public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> { + public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> +{ public: {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; }; diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 index 66df9cf56abcf..ff85d19410d07 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_h.j2 @@ -6,19 +6,19 @@ - 1 top-level structure corresponding to the request (e.g. `struct FetchRequest`), - N deserializers for top-level structure, one for each request version, - N parsers binding each deserializer with parser, - - 0+ child structures (e.g. `struct FetchRequestTopic`, `FetchRequestPartition`) that compose into top-level - structure, - - deserializers for each child structure (M = number of versions where structure is actually used). + - 0+ child structures (e.g. `struct FetchRequestTopic`, `FetchRequestPartition`) that are used by + the request's top-level structure, + - deserializers for each child structure. So for example, for FetchRequest we have: - struct FetchRequest, - FetchRequestV0Deserializer, FetchRequestV1Deserializer, FetchRequestV2Deserializer, etc., - FetchRequestV0Parser, FetchRequestV1Parser, FetchRequestV2Parser, etc., - struct FetchRequestTopic, - - FetchRequestTopicV0Deserializer, FetchRequestTopicV1Deserializer, FetchRequestTopicV2Deserializer, etc. + - FetchRequestTopicV0Deserializer, FetchRequestTopicV1Deserializer, etc. (because topic data is present in every FetchRequest version), - struct FetchRequestPartition, - - FetchRequestPartitionV0Deserializer, FetchRequestPartitionV1Deserializer, FetchRequestPartitionV2Deserializer, etc. + - FetchRequestPartitionV0Deserializer, FetchRequestPartitionV1Deserializer, etc. (because partition data is present in every FetchRequestTopic version). #} #pragma once diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 index 8d164bdf1fea5..155113d178ff3 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -16,7 +16,7 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -class RequestDecoderTest : public testing::Test { +class RequestTest : public testing::Test { public: Buffer::OwnedImpl buffer_; @@ -33,7 +33,7 @@ public: * Takes an instance of a request, serializes it, then deserializes it. * This method gets executed for every request * version pair. 
*/ -template std::shared_ptr RequestDecoderTest::serializeAndDeserialize(T request) { +template std::shared_ptr RequestTest::serializeAndDeserialize(T request) { MessageEncoderImpl serializer{buffer_}; serializer.encode(request); @@ -41,7 +41,8 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria RequestDecoder testee{RequestParserResolver::getDefaultInstance(), {mock_listener}}; MessageSharedPtr receivedMessage; - EXPECT_CALL(*mock_listener, onMessage(testing::_)).WillOnce(testing::SaveArg<0>(&receivedMessage)); + EXPECT_CALL(*mock_listener, onMessage(testing::_)) + .WillOnce(testing::SaveArg<0>(&receivedMessage)); testee.onData(buffer_); @@ -50,10 +51,11 @@ template std::shared_ptr RequestDecoderTest::serializeAndDeseria {# Concrete tests for each request_type and version (field_list). - Each request is naively constructed using some default values (put "string" as std::string, 32 as uint32_t, etc.). + Each request is naively constructed using some default values + (put "string" as std::string, 32 as uint32_t, etc.). #} {% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %} -TEST_F(RequestDecoderTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) { +TEST_F(RequestTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) { // given {{ request_type.name }} data = { {{ field_list.example_value() }} }; ConcreteRequest<{{ request_type.name }}> request = { { diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py index 21478845a61c9..100bf7593c71e 100755 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_generator.py @@ -6,7 +6,8 @@ def main(): Serialization composite generator script ~~~~~~~~~~~~~~~~~~~~~~~~~~~~~ Generates main&test source code files for composite deserializers. - The files are generated, as they are extremely repetitive (composite deserializer for 0..9 sub-deserializers). + The files are generated, as they are extremely repetitive (composite deserializer for 0..9 + sub-deserializers). Usage: serialization_composite_generator.py COMMAND LOCATION_OF_OUTPUT_FILE @@ -17,7 +18,8 @@ def main(): if generate-test: location of 'serialization_composite_test.cc'. When generating source code, it creates: - - serialization_composite.h - header with declarations of CompositeDeserializerWith???Delegates classes. + - serialization_composite.h - header with declarations of CompositeDeserializerWith???Delegates + classes. When generating test code, it creates: - serialization_composite_test.cc - tests for these classes. diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 index 44910f99161dd..271d0084332f5 100644 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 @@ -2,7 +2,8 @@ Creates 'serialization_composite.h'. Template for composite serializers (the CompositeDeserializerWith_N_Delegates classes). 
- Covers the corner case of 0 delegates, and then uses templating to create declarations for 1..N variants. + Covers the corner case of 0 delegates, and then uses templating to create declarations for 1..N + variants. #} #pragma once @@ -53,8 +54,9 @@ public: {% for field_count in counts %} /** * Composite deserializer that uses {{ field_count }} deserializer(s). - * Passes data to each of the underlying deserializers (deserializers that are already ready do not consume data, - * so it's safe). The composite deserializer is ready when the last deserializer is ready (what means that all + * Passes data to each of the underlying deserializers (deserializers that are already ready do not + * consume data, so it's safe). + * The composite deserializer is ready when the last deserializer is ready (what means that all * deserializers before it are ready too). * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } * @@ -62,7 +64,8 @@ public: * @param DeserializerType{{ field }} deserializer {{ field }} {% endfor %} */ template < - typename ResponseType{% for field in range(1, field_count + 1) %}, typename DeserializerType{{ field }}{% endfor %} + typename ResponseType{% for field in range(1, field_count + 1) %}, + typename DeserializerType{{ field }}{% endfor %} > class CompositeDeserializerWith{{ field_count }}Delegates : public Deserializer { public: diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 index 22533e3a78088..86658ccd38803 100644 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 @@ -15,7 +15,8 @@ namespace NetworkFilters { namespace Kafka { /** - * Tests in this class are supposed to check whether serialization operations on composite deserializers are correct. + * Tests in this class are supposed to check whether serialization operations on composite + * deserializers are correct. */ // Tests for composite deserializer with 0 fields (corner case). 
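To make the template output above concrete, the generated composite deserializer for two delegates has roughly the following shape (a sketch reconstructed from the template comments, not a verbatim copy of the generated code):

template <typename ResponseType, typename DeserializerType1, typename DeserializerType2>
class CompositeDeserializerWith2Delegates : public Deserializer<ResponseType> {
public:
  size_t feed(absl::string_view& data) override {
    size_t consumed = 0;
    // Delegates that are already ready consume no bytes, so unconditionally
    // feeding every delegate in order is safe.
    consumed += delegate1_.feed(data);
    consumed += delegate2_.feed(data);
    return consumed;
  }
  // The composite is ready when its last delegate is ready, which implies all
  // earlier delegates became ready first.
  bool ready() const override { return delegate2_.ready(); }
  // The result is built from the delegates' results, in declaration order.
  ResponseType get() const override { return {delegate1_.get(), delegate2_.get()}; }

private:
  DeserializerType1 delegate1_;
  DeserializerType2 delegate2_;
};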
@@ -57,7 +58,9 @@ struct CompositeResultWith{{ field_count }}Fields { } bool operator==(const CompositeResultWith{{ field_count }}Fields& rhs) const { - return true{% for field in range(1, field_count + 1) %} && field{{ field }}_ == rhs.field{{ field }}_{% endfor %}; + return true + {% for field in range(1, field_count + 1) %} && field{{ field }}_ == rhs.field{{ field }}_ + {% endfor %}; } }; diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index 0216968272a4c..43c62868e988e 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -15,7 +15,7 @@ namespace Kafka { const int32_t FAILED_DESERIALIZER_STEP = 13; -class BufferBasedTest : public testing::Test { +class KafkaRequestParserTest : public testing::Test { public: const char* getBytes() { uint64_t num_slices = buffer_.getRawSlices(nullptr, 0); @@ -44,7 +44,7 @@ class MockRequestParserResolver : public RequestParserResolver { MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); }; -TEST_F(BufferBasedTest, RequestStartParserTestShouldReturnRequestHeaderParser) { +TEST_F(KafkaRequestParserTest, RequestStartParserTestShouldReturnRequestHeaderParser) { // given MockRequestParserResolver resolver{}; RequestStartParser testee{resolver}; @@ -73,7 +73,7 @@ class MockParser : public Parser { } }; -TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNextParser) { +TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldExtractHeaderAndResolveNextParser) { // given const MockRequestParserResolver parser_resolver; const ParserSharedPtr parser{new MockParser{}}; @@ -112,7 +112,7 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldExtractHeaderDataAndResolveNext assertStringViewIncrement(data, orig_data, header_len); } -TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDuringFeeding) { +TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldHandleExceptionsDuringFeeding) { // given // This deserializer throws during feeding. @@ -155,7 +155,7 @@ TEST_F(BufferBasedTest, RequestHeaderParserShouldHandleDeserializerExceptionsDur assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); } -TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { +TEST_F(KafkaRequestParserTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { // given // This deserializer throws during feeding. @@ -201,7 +201,7 @@ class SomeBytesDeserializer : public Deserializer { int32_t get() const override { return 0; }; }; -TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyButLeavingData) { +TEST_F(KafkaRequestParserTest, RequestParserShouldHandleDeserializerReturningReadyButLeavingData) { // given const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. 
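The assertStringViewIncrement checks used throughout these tests verify one invariant: parsers and deserializers advance the caller's string_view in place. In isolation, independent of the Kafka types, the invariant amounts to:

const std::string payload = "abcdef";
const absl::string_view original{payload};
absl::string_view data = original;
data.remove_prefix(2); // stand-in for a parser consuming 2 bytes
ASSERT_EQ(data.data(), original.data() + 2); // the data pointer moved forward...
ASSERT_EQ(data.size(), original.size() - 2); // ...and the size shrank to match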
RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; @@ -225,7 +225,7 @@ TEST_F(BufferBasedTest, RequestParserShouldHandleDeserializerClaimingItsReadyBut assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); } -TEST_F(BufferBasedTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { +TEST_F(KafkaRequestParserTest, SentinelParserShouldConsumeDataUntilEndOfRequest) { // given const int32_t request_len = 1000; RequestContextSharedPtr context{new RequestContext()}; diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc index d692801fe1787..67c61b4e616e1 100644 --- a/test/extensions/filters/network/kafka/request_codec_integration_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -10,7 +10,7 @@ namespace Extensions { namespace NetworkFilters { namespace Kafka { -class RequestDecoderTest : public testing::Test { +class RequestCodecIntegrationTest : public testing::Test { protected: template void putInBuffer(T arg); @@ -18,7 +18,7 @@ class RequestDecoderTest : public testing::Test { }; // Other request types are tested in (generated) 'request_codec_request_integration_test.cc'. -TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { +TEST_F(RequestCodecIntegrationTest, shouldProduceAbortedMessageOnUnknownData) { // given // As real api keys have values below 100, the messages generated in this loop should not be // recognized by the codec. @@ -56,7 +56,7 @@ TEST_F(RequestDecoderTest, shouldProduceAbortedMessageOnUnknownData) { } // Helper function. -template void RequestDecoderTest::putInBuffer(T arg) { +template void RequestCodecIntegrationTest::putInBuffer(T arg) { MessageEncoderImpl serializer{buffer_}; serializer.encode(arg); } diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc index 4906bbd56711e..70bd819077afa 100644 --- a/test/extensions/filters/network/kafka/request_codec_unit_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -42,7 +42,7 @@ class MockRequestCallback : public RequestCallback { typedef std::shared_ptr MockRequestCallbackSharedPtr; -class RequestDecoderTest : public testing::Test { +class RequestCodecUnitTest : public testing::Test { protected: template void putInBuffer(T arg); @@ -58,7 +58,7 @@ ParseResponse consumeOneByte(absl::string_view& data) { return ParseResponse::stillWaiting(); } -TEST_F(RequestDecoderTest, shouldDoNothingIfParserNeverReturnsMessage) { +TEST_F(RequestCodecUnitTest, shouldDoNothingIfParserNeverReturnsMessage) { // given putInBuffer(ConcreteRequest{{}, 0}); @@ -76,7 +76,7 @@ TEST_F(RequestDecoderTest, shouldDoNothingIfParserNeverReturnsMessage) { // There were no interactions with `request_callback`. } -TEST_F(RequestDecoderTest, shouldUseNewParserAsResponse) { +TEST_F(RequestCodecUnitTest, shouldUseNewParserAsResponse) { // given putInBuffer(ConcreteRequest{{}, 0}); @@ -98,7 +98,7 @@ TEST_F(RequestDecoderTest, shouldUseNewParserAsResponse) { // There were no interactions with `request_callback`. 
} -TEST_F(RequestDecoderTest, shouldReturnParsedMessageAndReinitialize) { +TEST_F(RequestCodecUnitTest, shouldReturnParsedMessageAndReinitialize) { // given putInBuffer(ConcreteRequest{{}, 0}); @@ -124,7 +124,7 @@ TEST_F(RequestDecoderTest, shouldReturnParsedMessageAndReinitialize) { // There was only one message sent to `request_callback`. } -TEST_F(RequestDecoderTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { +TEST_F(RequestCodecUnitTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { // given putInBuffer(ConcreteRequest{{}, 0}); @@ -159,7 +159,7 @@ TEST_F(RequestDecoderTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { } // Helper function. -template void RequestDecoderTest::putInBuffer(T arg) { +template void RequestCodecUnitTest::putInBuffer(T arg) { MessageEncoderImpl serializer{buffer_}; serializer.encode(arg); } From 81c97c833cd5b61e5584cf1c9e5e3b694916cc2e Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Tue, 26 Mar 2019 16:08:55 -0700 Subject: [PATCH 23/29] Put Kafka tests in dedicated namespaces to avoid duplicate mock classes when running coverage builds; some renames Signed-off-by: Adam Kotwasinski --- .../protocol_code_generator/kafka_generator.py | 15 ++++++--------- ...est_cc.j2 => request_codec_request_test_cc.j2} | 4 +++- .../protocol_code_generator/requests_test_cc.j2 | 2 ++ test/extensions/filters/network/kafka/BUILD | 8 ++++---- .../network/kafka/kafka_request_parser_test.cc | 2 ++ .../kafka/request_codec_integration_test.cc | 4 +++- .../network/kafka/request_codec_unit_test.cc | 2 ++ .../filters/network/kafka/serialization_test.cc | 2 ++ 8 files changed, 24 insertions(+), 15 deletions(-) rename source/extensions/filters/network/kafka/protocol_code_generator/{request_codec_request_integration_test_cc.j2 => request_codec_request_test_cc.j2} (95%) diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py index 1d3b4886fa35a..663c8a58fbf84 100755 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_generator.py @@ -14,8 +14,7 @@ def main(): COMMAND : 'generate-source', to generate source files, 'generate-test', to generate test files. OUTPUT_FILES : if generate-source: location of 'requests.h' and 'kafka_request_resolver.cc', - if generate-test: location of 'requests_test.cc', - 'request_codec_request_integration_test.cc'. + if generate-test: location of 'requests_test.cc', 'request_codec_request_test.cc'. INPUT_FILES: Kafka protocol json files to be processed. Kafka spec files are provided in Kafka clients jar file. @@ -25,15 +24,13 @@ def main(): requests.h. When generating test code, it creates: - requests_test.cc - serialization/deserialization tests for kafka structures, - - request_codec_request_integration_test.cc - integration test for all request operations using - the codec API. + - request_codec_request_test.cc - test for all request operations using the codec API. Templates used are: - to create 'requests.h': requests_h.j2, complex_type_template.j2, request_parser.j2, - to create 'kafka_request_resolver.cc': kafka_request_resolver_cc.j2, - to create 'requests_test.cc': requests_test_cc.j2, - - to create 'request_codec_request_integration_test.cc' - - request_codec_request_integration_test_cc.j2. + - to create 'request_codec_request_test.cc' - request_codec_request_test_cc.j2. 
""" import sys @@ -46,7 +43,7 @@ def main(): input_files = sys.argv[4:] elif 'generate-test' == command: requests_test_cc_file = os.path.abspath(sys.argv[2]) - request_codec_request_integration_test_cc_file = os.path.abspath(sys.argv[3]) + request_codec_request_test_cc_file = os.path.abspath(sys.argv[3]) input_files = sys.argv[4:] else: raise ValueError('invalid command: ' + command) @@ -103,10 +100,10 @@ def main(): with open(requests_test_cc_file, 'w') as fd: fd.write(contents) - template = RenderingHelper.get_template('request_codec_request_integration_test_cc.j2') + template = RenderingHelper.get_template('request_codec_request_test_cc.j2') contents = template.render(request_types=requests) - with open(request_codec_request_integration_test_cc_file, 'w') as fd: + with open(request_codec_request_test_cc_file, 'w') as fd: fd.write(contents) diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 similarity index 95% rename from source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 rename to source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 index fb542372a1295..6d53f01ecfa46 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_integration_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 @@ -1,5 +1,5 @@ {# - Template for 'request_codec_request_integration_test.cc'. + Template for 'request_codec_request_test.cc'. Provides integration tests using Kafka codec. The tests do the following: @@ -21,6 +21,7 @@ namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +namespace RequestCodecRequestTest { class RequestCodecRequestTest : public testing::Test { protected: @@ -80,6 +81,7 @@ void RequestCodecRequestTest::putInBuffer(const T arg) { serializer.encode(arg); } +} // namespace RequestCodecRequestTest } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 index 155113d178ff3..e2abbf6b7d979 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -15,6 +15,7 @@ namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +namespace RequestTest { class RequestTest : public testing::Test { public: @@ -70,6 +71,7 @@ TEST_F(RequestTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) } {% endfor %}{% endfor %} +} // namespace RequestTest } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index e065402656545..d8f8fc082501d 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -91,8 +91,8 @@ envoy_extension_cc_test( ) envoy_extension_cc_test( - name = "request_codec_request_integration_test", - srcs = ["request_codec_request_integration_test.cc"], + name = "request_codec_request_test", + srcs = ["request_codec_request_test.cc"], extension_name = 
"envoy.filters.network.kafka", deps = [ ":serialization_utilities_lib", @@ -118,11 +118,11 @@ genrule( ], outs = [ "requests_test.cc", - "request_codec_request_integration_test.cc", + "request_codec_request_test.cc", ], cmd = """ ./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test \ - $(location requests_test.cc) $(location request_codec_request_integration_test.cc) \ + $(location requests_test.cc) $(location request_codec_request_test.cc) \ $(SRCS) """, tools = [ diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index 43c62868e988e..c23c392840d57 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -12,6 +12,7 @@ namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +namespace KafkaRequestParserTest { const int32_t FAILED_DESERIALIZER_STEP = 13; @@ -248,6 +249,7 @@ TEST_F(KafkaRequestParserTest, SentinelParserShouldConsumeDataUntilEndOfRequest) assertStringViewIncrement(data, orig_data, request_len); } +} // namespace KafkaRequestParserTest } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc index 67c61b4e616e1..b82042ebf8635 100644 --- a/test/extensions/filters/network/kafka/request_codec_integration_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -9,6 +9,7 @@ namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +namespace RequestCodecIntegrationTest { class RequestCodecIntegrationTest : public testing::Test { protected: @@ -17,7 +18,7 @@ class RequestCodecIntegrationTest : public testing::Test { Buffer::OwnedImpl buffer_; }; -// Other request types are tested in (generated) 'request_codec_request_integration_test.cc'. +// Other request types are tested in (generated) 'request_codec_request_test.cc'. 
TEST_F(RequestCodecIntegrationTest, shouldProduceAbortedMessageOnUnknownData) { // given // As real api keys have values below 100, the messages generated in this loop should not be @@ -61,6 +62,7 @@ template void RequestCodecIntegrationTest::putInBuffer(T arg) { serializer.encode(arg); } +} // namespace RequestCodecIntegrationTest } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc index 70bd819077afa..39f9256ae660b 100644 --- a/test/extensions/filters/network/kafka/request_codec_unit_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -16,6 +16,7 @@ namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +namespace RequestCodecUnitTest { class MockParserFactory : public InitialParserFactory { public: @@ -164,6 +165,7 @@ template void RequestCodecUnitTest::putInBuffer(T arg) { serializer.encode(arg); } +} // namespace RequestCodecUnitTest } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index ff2e5de20e67f..bc046448d4ab1 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -4,6 +4,7 @@ namespace Envoy { namespace Extensions { namespace NetworkFilters { namespace Kafka { +namespace SerializationTest { /** * Tests in this file are supposed to check whether serialization operations @@ -215,6 +216,7 @@ TEST(NullableArrayDeserializer, ShouldThrowOnInvalidLength) { EXPECT_THROW(testee.feed(data), EnvoyException); } +} // namespace SerializationTest } // namespace Kafka } // namespace NetworkFilters } // namespace Extensions From d47c2c2d59d921f4feb065584d77833da1a52867 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Thu, 28 Mar 2019 12:47:59 -0700 Subject: [PATCH 24/29] Kick CI Signed-off-by: Adam Kotwasinski From 145f28d8a6db00e646725bf8422bf450482df6d5 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Thu, 28 Mar 2019 13:50:41 -0700 Subject: [PATCH 25/29] Kick CI Signed-off-by: Adam Kotwasinski From 11d9288c97171374bfda7e26262543ddca3fa8dd Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Thu, 28 Mar 2019 16:03:56 -0700 Subject: [PATCH 26/29] Kick CI Signed-off-by: Adam Kotwasinski From 7f1abca1074d84598ae73465a5ddb06e71c0bd3b Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Thu, 25 Apr 2019 10:01:54 -0700 Subject: [PATCH 27/29] Put generated files in directories named 'external', so they do not get picked up by gcovr Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/BUILD | 16 ++++++++-------- .../filters/network/kafka/kafka_request.h | 2 +- .../kafka_request_resolver_cc.j2 | 2 +- .../request_codec_request_test_cc.j2 | 2 +- .../protocol_code_generator/requests_test_cc.j2 | 2 +- .../serialization_composite_test_cc.j2 | 2 +- test/extensions/filters/network/kafka/BUILD | 16 ++++++++-------- 7 files changed, 21 insertions(+), 21 deletions(-) diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index f4cc3ce6bb8b4..031b786dabd0b 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -28,13 +28,13 @@ envoy_cc_library( envoy_cc_library( name = "kafka_request_lib", srcs = 
[ + "external/kafka_request_resolver.cc", "kafka_request_parser.cc", - "kafka_request_resolver.cc", ], hdrs = [ + "external/requests.h", "kafka_request.h", "kafka_request_parser.h", - "requests.h", ], deps = [ ":parser_lib", @@ -50,12 +50,12 @@ genrule( "@kafka_source//:request_protocol_files", ], outs = [ - "requests.h", - "kafka_request_resolver.cc", + "external/requests.h", + "external/kafka_request_resolver.cc", ], cmd = """ ./$(location :kafka_code_generator) generate-source \ - $(location requests.h) $(location kafka_request_resolver.cc) \ + $(location external/requests.h) $(location external/kafka_request_resolver.cc) \ $(SRCS) """, tools = [ @@ -94,8 +94,8 @@ envoy_cc_library( envoy_cc_library( name = "serialization_lib", hdrs = [ + "external/serialization_composite.h", "serialization.h", - "serialization_composite.h", ], deps = [ ":kafka_protocol_lib", @@ -108,11 +108,11 @@ genrule( name = "serialization_composite_generated_source", srcs = [], outs = [ - "serialization_composite.h", + "external/serialization_composite.h", ], cmd = """ ./$(location :serialization_composite_generator) generate-source \ - $(location serialization_composite.h) + $(location external/serialization_composite.h) """, tools = [ ":serialization_composite_generator", diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index 759d0ddd6a7e7..b48c48eb6bf76 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -2,9 +2,9 @@ #include "envoy/common/exception.h" +#include "extensions/filters/network/kafka/external/serialization_composite.h" #include "extensions/filters/network/kafka/message.h" #include "extensions/filters/network/kafka/serialization.h" -#include "extensions/filters/network/kafka/serialization_composite.h" namespace Envoy { namespace Extensions { diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 index 9a8de2bf087d4..bdffd5956ca55 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 @@ -3,7 +3,7 @@ Defines default Kafka request resolver, that uses request parsers in (also generated) 'requests.h'. #} -#include "extensions/filters/network/kafka/requests.h" +#include "extensions/filters/network/kafka/external/requests.h" #include "extensions/filters/network/kafka/kafka_request_parser.h" #include "extensions/filters/network/kafka/parser.h" diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 index 6d53f01ecfa46..74ab7624b0f5e 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 @@ -9,8 +9,8 @@ - capture messages received in callback, - verify that captured messages are identical to the ones sent. 
#} +#include "extensions/filters/network/kafka/external/requests.h" #include "extensions/filters/network/kafka/request_codec.h" -#include "extensions/filters/network/kafka/requests.h" #include "test/extensions/filters/network/kafka/serialization_utilities.h" #include "test/mocks/server/mocks.h" diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 index e2abbf6b7d979..4ed97ab2a0b37 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -3,7 +3,7 @@ For every request, we want to check if it can be serialized and deserialized properly. #} -#include "extensions/filters/network/kafka/requests.h" +#include "extensions/filters/network/kafka/external/requests.h" #include "extensions/filters/network/kafka/request_codec.h" #include "test/mocks/server/mocks.h" diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 index 86658ccd38803..ee6204171eb64 100644 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_test_cc.j2 @@ -5,7 +5,7 @@ Covers the corner case of 0 delegates, and then uses templating to create tests for 1..N cases. #} -#include "extensions/filters/network/kafka/serialization_composite.h" +#include "extensions/filters/network/kafka/external/serialization_composite.h" #include "test/extensions/filters/network/kafka/serialization_utilities.h" diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index d8f8fc082501d..c7cd341f5192b 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -36,7 +36,7 @@ envoy_extension_cc_test( envoy_extension_cc_test( name = "serialization_composite_test", - srcs = ["serialization_composite_test.cc"], + srcs = ["external/serialization_composite_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ ":serialization_utilities_lib", @@ -48,10 +48,10 @@ envoy_extension_cc_test( genrule( name = "serialization_composite_test_generator", srcs = [], - outs = ["serialization_composite_test.cc"], + outs = ["external/serialization_composite_test.cc"], cmd = """ ./$(location //source/extensions/filters/network/kafka:serialization_composite_generator) \ - generate-test $(location serialization_composite_test.cc) + generate-test $(location external/serialization_composite_test.cc) """, tools = [ "//source/extensions/filters/network/kafka:serialization_composite_generator", @@ -92,7 +92,7 @@ envoy_extension_cc_test( envoy_extension_cc_test( name = "request_codec_request_test", - srcs = ["request_codec_request_test.cc"], + srcs = ["external/request_codec_request_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ ":serialization_utilities_lib", @@ -103,7 +103,7 @@ envoy_extension_cc_test( envoy_extension_cc_test( name = "requests_test", - srcs = ["requests_test.cc"], + srcs = ["external/requests_test.cc"], extension_name = "envoy.filters.network.kafka", deps = [ "//source/extensions/filters/network/kafka:kafka_request_codec_lib", @@ -117,12 +117,12 @@ genrule( 
"@kafka_source//:request_protocol_files", ], outs = [ - "requests_test.cc", - "request_codec_request_test.cc", + "external/requests_test.cc", + "external/request_codec_request_test.cc", ], cmd = """ ./$(location //source/extensions/filters/network/kafka:kafka_code_generator) generate-test \ - $(location requests_test.cc) $(location request_codec_request_test.cc) \ + $(location external/requests_test.cc) $(location external/request_codec_request_test.cc) \ $(SRCS) """, tools = [ From 36f76c751e852c5f08a03dc04ff0ffe56d043067 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Thu, 25 Apr 2019 12:14:30 -0700 Subject: [PATCH 28/29] Add missing test to NullableArrayDeserializer Signed-off-by: Adam Kotwasinski --- test/extensions/filters/network/kafka/serialization_test.cc | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/extensions/filters/network/kafka/serialization_test.cc b/test/extensions/filters/network/kafka/serialization_test.cc index bc046448d4ab1..d52cdff25fe13 100644 --- a/test/extensions/filters/network/kafka/serialization_test.cc +++ b/test/extensions/filters/network/kafka/serialization_test.cc @@ -201,6 +201,12 @@ TEST(NullableArrayDeserializer, ShouldConsumeCorrectAmountOfData) { NullableArrayDeserializer>(value); } +TEST(NullableArrayDeserializer, ShouldConsumeNullArray) { + const NullableArray value = absl::nullopt; + serializeThenDeserializeAndCheckEquality< + NullableArrayDeserializer>(value); +} + TEST(NullableArrayDeserializer, ShouldThrowOnInvalidLength) { // given NullableArrayDeserializer testee; From 98c775279c069d99fddfa9116030a3727b89d5b7 Mon Sep 17 00:00:00 2001 From: Adam Kotwasinski Date: Thu, 25 Apr 2019 13:03:27 -0700 Subject: [PATCH 29/29] Refactoring: - Put request parse failures in separate objects; - Simplify message hierarchy; - Remove message.h and make Encoder/Parser/ParseResponse templated to support Response objects in future Signed-off-by: Adam Kotwasinski --- source/extensions/filters/network/kafka/BUILD | 36 ++++---- .../extensions/filters/network/kafka/codec.h | 11 +-- .../filters/network/kafka/kafka_request.h | 58 +++++++----- .../network/kafka/kafka_request_parser.cc | 24 ++--- .../network/kafka/kafka_request_parser.h | 55 ++++++----- .../filters/network/kafka/message.h | 33 ------- .../extensions/filters/network/kafka/parser.h | 51 +++++++---- .../kafka_request_resolver_cc.j2 | 4 +- .../request_codec_request_test_cc.j2 | 15 +-- .../protocol_code_generator/request_parser.j2 | 8 +- .../requests_test_cc.j2 | 11 ++- .../filters/network/kafka/request_codec.cc | 18 ++-- .../filters/network/kafka/request_codec.h | 61 +++++++------ .../filters/network/kafka/serialization.h | 28 +++--- .../serialization_composite_h.j2 | 8 +- test/extensions/filters/network/kafka/BUILD | 2 +- .../kafka/kafka_request_parser_test.cc | 35 ++++--- .../kafka/request_codec_integration_test.cc | 23 +++-- .../network/kafka/request_codec_unit_test.cc | 91 +++++++++++++------ .../network/kafka/serialization_utilities.cc | 15 ++- .../network/kafka/serialization_utilities.h | 11 ++- 21 files changed, 336 insertions(+), 262 deletions(-) delete mode 100644 source/extensions/filters/network/kafka/message.h diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index 031b786dabd0b..73bab5124c8e7 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -19,31 +19,41 @@ envoy_cc_library( "request_codec.h", ], deps = [ - ":kafka_request_lib", - ":message_lib", + 
":kafka_request_parser_lib", "//source/common/buffer:buffer_lib", ], ) envoy_cc_library( - name = "kafka_request_lib", + name = "kafka_request_parser_lib", srcs = [ "external/kafka_request_resolver.cc", "kafka_request_parser.cc", ], hdrs = [ "external/requests.h", - "kafka_request.h", "kafka_request_parser.h", ], deps = [ + ":kafka_request_lib", ":parser_lib", - ":serialization_lib", "//source/common/common:assert_lib", "//source/common/common:minimal_logger_lib", ], ) +envoy_cc_library( + name = "kafka_request_lib", + srcs = [ + ], + hdrs = [ + "kafka_request.h", + ], + deps = [ + ":serialization_lib", + ], +) + genrule( name = "kafka_generated_source", srcs = [ @@ -75,22 +85,10 @@ envoy_cc_library( name = "parser_lib", hdrs = ["parser.h"], deps = [ - ":kafka_protocol_lib", - ":message_lib", "//source/common/common:minimal_logger_lib", ], ) -envoy_cc_library( - name = "message_lib", - hdrs = [ - "message.h", - ], - deps = [ - "//include/envoy/buffer:buffer_interface", - ], -) - envoy_cc_library( name = "serialization_lib", hdrs = [ @@ -98,7 +96,7 @@ envoy_cc_library( "serialization.h", ], deps = [ - ":kafka_protocol_lib", + ":kafka_types_lib", "//include/envoy/buffer:buffer_interface", "//source/common/common:byte_order_lib", ], @@ -128,7 +126,7 @@ py_binary( ) envoy_cc_library( - name = "kafka_protocol_lib", + name = "kafka_types_lib", hdrs = [ "kafka_types.h", ], diff --git a/source/extensions/filters/network/kafka/codec.h b/source/extensions/filters/network/kafka/codec.h index 01b9a3c84ad15..a58c284a052a1 100644 --- a/source/extensions/filters/network/kafka/codec.h +++ b/source/extensions/filters/network/kafka/codec.h @@ -3,8 +3,6 @@ #include "envoy/buffer/buffer.h" #include "envoy/common/pure.h" -#include "extensions/filters/network/kafka/message.h" - namespace Envoy { namespace Extensions { namespace NetworkFilters { @@ -19,23 +17,24 @@ class MessageDecoder { /** * Processes given buffer attempting to decode messages contained within. - * @param data buffer instance + * @param data buffer instance. */ virtual void onData(Buffer::Instance& data) PURE; }; /** * Kafka message encoder. + * @param MessageType encoded message type (request or response). */ -class MessageEncoder { +template class MessageEncoder { public: virtual ~MessageEncoder() = default; /** * Encodes given message. - * @param message message to be encoded + * @param message message to be encoded. */ - virtual void encode(const Message& message) PURE; + virtual void encode(const MessageType& message) PURE; }; } // namespace Kafka diff --git a/source/extensions/filters/network/kafka/kafka_request.h b/source/extensions/filters/network/kafka/kafka_request.h index b48c48eb6bf76..e15605515b75a 100644 --- a/source/extensions/filters/network/kafka/kafka_request.h +++ b/source/extensions/filters/network/kafka/kafka_request.h @@ -3,7 +3,6 @@ #include "envoy/common/exception.h" #include "extensions/filters/network/kafka/external/serialization_composite.h" -#include "extensions/filters/network/kafka/message.h" #include "extensions/filters/network/kafka/serialization.h" namespace Envoy { @@ -27,30 +26,60 @@ struct RequestHeader { }; }; +/** + * Carries information that could be extracted during the failed parse. + */ +class RequestParseFailure { +public: + RequestParseFailure(const RequestHeader& request_header) : request_header_{request_header} {}; + + /** + * Request's header. + */ + const RequestHeader request_header_; +}; + +typedef std::shared_ptr RequestParseFailureSharedPtr; + /** * Abstract Kafka request. 
 * Contains data present in every request (the header with request key, version, etc.).
 * @see http://kafka.apache.org/protocol.html#protocol_messages
 */
-class AbstractRequest : public Message {
+class AbstractRequest {
 public:
+  virtual ~AbstractRequest() = default;
+
+  /**
+   * Constructs a request with given header data.
+   * @param request_header request's header.
+   */
   AbstractRequest(const RequestHeader& request_header) : request_header_{request_header} {};
 
+  /**
+   * Encode the contents of this message into a given buffer.
+   * @param dst buffer instance to keep serialized message
+   */
+  virtual size_t encode(Buffer::Instance& dst) const PURE;
+
   /**
    * Request's header.
    */
   const RequestHeader request_header_;
 };
 
+typedef std::shared_ptr<AbstractRequest> AbstractRequestSharedPtr;
+
 /**
  * Concrete request that carries data particular to given request type.
+ * @param Data concrete request data type.
  */
-template <typename RequestData> class ConcreteRequest : public AbstractRequest {
+template <typename Data> class Request : public AbstractRequest {
 public:
   /**
    * Request header fields need to be initialized by user in case of newly created requests.
    */
-  ConcreteRequest(const RequestHeader& request_header, const RequestData& data)
+  Request(const RequestHeader& request_header, const Data& data)
       : AbstractRequest{request_header}, data_{data} {};
 
   /**
@@ -69,29 +98,12 @@ template <typename RequestData> class ConcreteRequest : public AbstractRequest {
     return written;
   }
 
-  bool operator==(const ConcreteRequest& rhs) const {
+  bool operator==(const Request& rhs) const {
     return request_header_ == rhs.request_header_ && data_ == rhs.data_;
   };
 
 private:
-  const RequestData data_;
-};
-
-/**
- * Request that did not have api_key & api_version that could be matched with any of
- * request-specific parsers.
- * Right now it acts as a placeholder only, and does not carry the request data.
- */
-class UnknownRequest : public AbstractRequest {
-public:
-  UnknownRequest(const RequestHeader& request_header) : AbstractRequest{request_header} {};
-
-  /**
-   * It is impossible to encode unknown request, as it is only a placeholder.
-   */
-  size_t encode(Buffer::Instance&) const override {
-    throw EnvoyException("cannot serialize unknown request");
-  }
+  const Data data_;
 };
 
 } // namespace Kafka
diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.cc b/source/extensions/filters/network/kafka/kafka_request_parser.cc
index db45d33a05d83..0245a9bff24b9 100644
--- a/source/extensions/filters/network/kafka/kafka_request_parser.cc
+++ b/source/extensions/filters/network/kafka/kafka_request_parser.cc
@@ -9,18 +9,18 @@ const RequestParserResolver& RequestParserResolver::getDefaultInstance() {
   CONSTRUCT_ON_FIRST_USE(RequestParserResolver);
 }
 
-ParseResponse RequestStartParser::parse(absl::string_view& data) {
+RequestParseResponse RequestStartParser::parse(absl::string_view& data) {
   request_length_.feed(data);
   if (request_length_.ready()) {
     context_->remaining_request_size_ = request_length_.get();
-    return ParseResponse::nextParser(
+    return RequestParseResponse::nextParser(
         std::make_shared<RequestHeaderParser>(parser_resolver_, context_));
   } else {
-    return ParseResponse::stillWaiting();
+    return RequestParseResponse::stillWaiting();
   }
 }
 
-ParseResponse RequestHeaderParser::parse(absl::string_view& data) {
+RequestParseResponse RequestHeaderParser::parse(absl::string_view& data) {
   const absl::string_view orig_data = data;
   try {
     context_->remaining_request_size_ -= deserializer_->feed(data);
@@ -30,29 +30,29 @@ ParseResponse RequestHeaderParser::parse(absl::string_view& data) {
     const int32_t consumed = static_cast<int32_t>(orig_data.size() - data.size());
     context_->remaining_request_size_ -= consumed;
     context_->request_header_ = {-1, -1, -1, absl::nullopt};
-    return ParseResponse::nextParser(std::make_shared<SentinelParser>(context_));
+    return RequestParseResponse::nextParser(std::make_shared<SentinelParser>(context_));
   }
 
   if (deserializer_->ready()) {
     RequestHeader request_header = deserializer_->get();
     context_->request_header_ = request_header;
-    ParserSharedPtr next_parser = parser_resolver_.createParser(
+    RequestParserSharedPtr next_parser = parser_resolver_.createParser(
         request_header.api_key_, request_header.api_version_, context_);
-    return ParseResponse::nextParser(next_parser);
+    return RequestParseResponse::nextParser(next_parser);
   } else {
-    return ParseResponse::stillWaiting();
+    return RequestParseResponse::stillWaiting();
   }
 }
 
-ParseResponse SentinelParser::parse(absl::string_view& data) {
+RequestParseResponse SentinelParser::parse(absl::string_view& data) {
   const size_t min = std::min(context_->remaining_request_size_, data.size());
   data = {data.data() + min, data.size() - min};
   context_->remaining_request_size_ -= min;
 
   if (0 == context_->remaining_request_size_) {
-    return ParseResponse::parsedMessage(
-        std::make_shared<UnknownRequest>(context_->request_header_));
+    return RequestParseResponse::parseFailure(
+        std::make_shared<RequestParseFailure>(context_->request_header_));
   } else {
-    return ParseResponse::stillWaiting();
+    return RequestParseResponse::stillWaiting();
   }
 }
diff --git a/source/extensions/filters/network/kafka/kafka_request_parser.h b/source/extensions/filters/network/kafka/kafka_request_parser.h
index 404a37b2a6fff..861d4dc4a3a9d 100644
--- a/source/extensions/filters/network/kafka/kafka_request_parser.h
+++ b/source/extensions/filters/network/kafka/kafka_request_parser.h
@@ -14,6 +14,10 @@ namespace Extensions {
 namespace NetworkFilters {
 namespace Kafka {
 
+using RequestParseResponse = ParseResponse<AbstractRequestSharedPtr, RequestParseFailureSharedPtr>;
+using RequestParser = Parser<AbstractRequestSharedPtr, RequestParseFailureSharedPtr>;
+using RequestParserSharedPtr = std::shared_ptr<RequestParser>;
+
 /**
  * Context that is shared between parsers that are handling the same single message.
 */
@@ -35,13 +39,13 @@ class RequestParserResolver {
 
   /**
    * Creates a parser that is going to process data specific for given api_key & api_version.
-   * @param api_key request type
-   * @param api_version request version
-   * @param context context to be used by parser
-   * @return parser that is capable of processing data for given request type & version
+   * @param api_key request type.
+   * @param api_version request version.
+   * @param context context to be used by parser.
+   * @return parser that is capable of processing data for given request type & version.
    */
-  virtual ParserSharedPtr createParser(int16_t api_key, int16_t api_version,
-                                       RequestContextSharedPtr context) const;
+  virtual RequestParserSharedPtr createParser(int16_t api_key, int16_t api_version,
+                                              RequestContextSharedPtr context) const;
 
   /**
    * Return default resolver, that uses request's api key and version to provide a matching parser.
@@ -53,16 +57,16 @@
  * Request parser responsible for consuming request length and setting up context with this data.
  * @see http://kafka.apache.org/protocol.html#protocol_common
  */
-class RequestStartParser : public Parser {
+class RequestStartParser : public RequestParser {
 public:
   RequestStartParser(const RequestParserResolver& parser_resolver)
       : parser_resolver_{parser_resolver}, context_{std::make_shared<RequestContext>()} {};
 
   /**
    * Consumes 4 bytes (INT32) as request length and updates the context with that value.
-   * @return RequestHeaderParser instance to process request header
+   * @return RequestHeaderParser instance to process request header.
    */
-  ParseResponse parse(absl::string_view& data) override;
+  RequestParseResponse parse(absl::string_view& data) override;
 
   const RequestContextSharedPtr contextForTest() const { return context_; }
 
@@ -90,7 +94,7 @@ typedef std::unique_ptr<RequestHeaderDeserializer> RequestHeaderDeserializerPtr;
  * parser.
  * @see http://kafka.apache.org/protocol.html#protocol_messages
  */
-class RequestHeaderParser : public Parser {
+class RequestHeaderParser : public RequestParser {
 public:
   // Default constructor.
   RequestHeaderParser(const RequestParserResolver& parser_resolver, RequestContextSharedPtr context)
@@ -107,7 +111,7 @@
    * Uses data provided to compute request header.
    * @return Parser instance responsible for processing rest of the message
    */
-  ParseResponse parse(absl::string_view& data) override;
+  RequestParseResponse parse(absl::string_view& data) override;
 
   const RequestContextSharedPtr contextForTest() const { return context_; }
 
@@ -122,14 +126,14 @@
  * api_key & api_version. It does not attempt to capture any data, just throws it away until end of
  * message.
  */
-class SentinelParser : public Parser {
+class SentinelParser : public RequestParser {
 public:
   SentinelParser(RequestContextSharedPtr context) : context_{context} {};
 
   /**
-   * Returns UnknownRequest. Ignores (jumps over) the data provided.
+   * Returns failed parse data. Ignores (jumps over) the data provided.
    */
-  ParseResponse parse(absl::string_view& data) override;
+  RequestParseResponse parse(absl::string_view& data) override;
 
   const RequestContextSharedPtr contextForTest() const { return context_; }
 
@@ -141,38 +145,39 @@
 /**
  * Request parser uses a single deserializer to construct a request object.
  * This parser is responsible for consuming request-specific data (e.g. topic names) and always
 * returns a parsed message.
- * @param RequestType request class
+ * @param RequestType request class.
  * @param DeserializerType deserializer type corresponding to request class (should be subclass of
- * Deserializer)
+ * Deserializer).
  */
-template <typename RequestType, typename DeserializerType> class RequestParser : public Parser {
+template <typename RequestType, typename DeserializerType>
+class RequestDataParser : public RequestParser {
 public:
   /**
    * Create a parser with given context.
-   * @param context parse context containing request header
+   * @param context parse context containing request header.
    */
-  RequestParser(RequestContextSharedPtr context) : context_{context} {};
+  RequestDataParser(RequestContextSharedPtr context) : context_{context} {};
 
   /**
    * Consume enough data to fill in deserializer and receive the parsed request.
    * Fill in request's header with data stored in context.
    */
-  ParseResponse parse(absl::string_view& data) override {
+  RequestParseResponse parse(absl::string_view& data) override {
     context_->remaining_request_size_ -= deserializer.feed(data);
     if (deserializer.ready()) {
       if (0 == context_->remaining_request_size_) {
         // After a successful parse, there should be nothing left - we have consumed all the bytes.
-        MessageSharedPtr msg = std::make_shared<ConcreteRequest<RequestType>>(
-            context_->request_header_, deserializer.get());
-        return ParseResponse::parsedMessage(msg);
+        AbstractRequestSharedPtr msg =
+            std::make_shared<Request<RequestType>>(context_->request_header_, deserializer.get());
+        return RequestParseResponse::parsedMessage(msg);
       } else {
         // The message makes no sense, the deserializer that matches the schema consumed all
         // necessary data, but there are still bytes in this message.
-        return ParseResponse::nextParser(std::make_shared<SentinelParser>(context_));
+        return RequestParseResponse::nextParser(std::make_shared<SentinelParser>(context_));
       }
     } else {
-      return ParseResponse::stillWaiting();
+      return RequestParseResponse::stillWaiting();
     }
   }
 
diff --git a/source/extensions/filters/network/kafka/message.h b/source/extensions/filters/network/kafka/message.h
deleted file mode 100644
index e6747ad7f453c..0000000000000
--- a/source/extensions/filters/network/kafka/message.h
+++ /dev/null
@@ -1,33 +0,0 @@
-#pragma once
-
-#include <memory>
-#include <string>
-
-#include "envoy/buffer/buffer.h"
-#include "envoy/common/pure.h"
-
-namespace Envoy {
-namespace Extensions {
-namespace NetworkFilters {
-namespace Kafka {
-
-/**
- * Abstract message (that can be either request or response).
- */
-class Message {
-public:
-  virtual ~Message() = default;
-
-  /**
-   * Encode the contents of this message into a given buffer.
-   * @param dst buffer instance to keep serialized message
-   */
-  virtual size_t encode(Buffer::Instance& dst) const PURE;
-};
-
-typedef std::shared_ptr<Message> MessageSharedPtr;
-
-} // namespace Kafka
-} // namespace NetworkFilters
-} // namespace Extensions
-} // namespace Envoy
diff --git a/source/extensions/filters/network/kafka/parser.h b/source/extensions/filters/network/kafka/parser.h
index 2ef06dbc2b1fc..031aaaef1dbd3 100644
--- a/source/extensions/filters/network/kafka/parser.h
+++ b/source/extensions/filters/network/kafka/parser.h
@@ -4,9 +4,6 @@
 
 #include "common/common/logger.h"
 
-#include "extensions/filters/network/kafka/kafka_types.h"
-#include "extensions/filters/network/kafka/message.h"
-
 #include "absl/strings/string_view.h"
 
 namespace Envoy {
@@ -14,12 +11,13 @@ namespace Extensions {
 namespace NetworkFilters {
 namespace Kafka {
 
-class ParseResponse;
+template <typename MessageType, typename FailureDataType> class ParseResponse;
 
 /**
  * Parser is responsible for consuming data relevant to some part of a message, and then returning
 * the decision how the parsing should continue.
 */
+template <typename MessageType, typename FailureDataType>
 class Parser : public Logger::Loggable<Logger::Id::kafka> {
 public:
   virtual ~Parser() = default;
 
   /**
    * Submit data to be processed by parser, will consume as much data as it is necessary to reach
    * the conclusion what should be the next parse step.
-   * @param data bytes to be processed, will be updated by parser if any have been consumed
-   * @return parse status - decision what should be done with current parser (keep/replace)
+   * @param data bytes to be processed, will be updated by parser if any have been consumed.
+   * @return parse status - decision what should be done with current parser (keep/replace).
    */
-  virtual ParseResponse parse(absl::string_view& data) PURE;
+  virtual ParseResponse<MessageType, FailureDataType> parse(absl::string_view& data) PURE;
 };
 
-typedef std::shared_ptr<Parser> ParserSharedPtr;
+template <typename MessageType, typename FailureDataType>
+using ParserSharedPtr = std::shared_ptr<Parser<MessageType, FailureDataType>>;
 
 /**
  * Three-state holder representing one of:
 * - parser still needs data (`stillWaiting`),
 * - parser is finished, and following parser should be used to process the rest of data
 *   (`nextParser`),
- * - parser is finished, and fully-parsed message is attached (`parsedMessage`).
+ * - parser is finished, and parse result is attached (`parsedMessage` or `parseFailure`).
 */
-class ParseResponse {
+template <typename MessageType, typename FailureDataType> class ParseResponse {
 public:
   /**
    * Constructs a response that states that parser still needs data and should not be replaced.
    */
-  static ParseResponse stillWaiting() { return {nullptr, nullptr}; }
+  static ParseResponse stillWaiting() { return {nullptr, nullptr, nullptr}; }
 
   /**
    * Constructs a response that states that parser is finished and should be replaced by given
    * parser.
    */
-  static ParseResponse nextParser(ParserSharedPtr next_parser) { return {next_parser, nullptr}; };
+  static ParseResponse nextParser(ParserSharedPtr<MessageType, FailureDataType> next_parser) {
+    return {next_parser, nullptr, nullptr};
+  };
 
   /**
    * Constructs a response that states that parser is finished, the message is ready, and parsing
    * can start anew for next message.
    */
-  static ParseResponse parsedMessage(MessageSharedPtr message) { return {nullptr, message}; };
+  static ParseResponse parsedMessage(MessageType message) { return {nullptr, message, nullptr}; };
+
+  /**
+   * Constructs a response that states that parser is finished, the message could not be parsed
+   * properly, and parsing can start anew for next message.
+   */
+  static ParseResponse parseFailure(FailureDataType failure_data) {
+    return {nullptr, nullptr, failure_data};
+  };
 
   /**
-   * If response contains a next parser or the fully parsed message.
+   * If response contains a next parser or a parse result.
   */
-  bool hasData() const { return (next_parser_ != nullptr) || (message_ != nullptr); }
+  bool hasData() const {
+    return (next_parser_ != nullptr) || (message_ != nullptr) || (failure_data_ != nullptr);
+  }
 
 private:
-  ParseResponse(ParserSharedPtr parser, MessageSharedPtr message)
-      : next_parser_{parser}, message_{message} {};
+  ParseResponse(ParserSharedPtr<MessageType, FailureDataType> parser, MessageType message,
+                FailureDataType failure_data)
+      : next_parser_{parser}, message_{message}, failure_data_{failure_data} {};
 
 public:
-  ParserSharedPtr next_parser_;
-  MessageSharedPtr message_;
+  ParserSharedPtr<MessageType, FailureDataType> next_parser_;
+  MessageType message_;
+  FailureDataType failure_data_;
 };
 
 } // namespace Kafka
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2
index bdffd5956ca55..d73f76955adca 100644
--- a/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/kafka_request_resolver_cc.j2
@@ -20,8 +20,8 @@ namespace Kafka {
  * @param api_version Kafka request's version
  * @param context parse context
  */
-ParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version,
-                                                    RequestContextSharedPtr context) const {
+RequestParserSharedPtr RequestParserResolver::createParser(int16_t api_key, int16_t api_version,
+                                                           RequestContextSharedPtr context) const {
 {% for request_type in request_types %}{% for field_list in request_type.compute_field_lists() %}
   if ({{ request_type.get_extra('api_key') }} == api_key
diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2
index 74ab7624b0f5e..c853563f8f8a9 100644
--- a/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2
+++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_codec_request_test_cc.j2
@@ -36,9 +36,9 @@ protected:
 TEST_F(RequestCodecRequestTest, shouldHandle{{ request_type.name }}Messages) {
 
   // given
-  using Request = ConcreteRequest<{{ request_type.name }}>;
+  using RequestUnderTest = Request<{{ request_type.name }}>;
 
-  std::vector<Request> sent;
+  std::vector<RequestUnderTest> sent;
   int32_t correlation = 0;
 {% for field_list in request_type.compute_field_lists() %}
@@ -46,7 +46,7 @@
     const RequestHeader header = { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, correlation++, "id" };
     const {{ request_type.name }} data = { {{ field_list.example_value() }} };
-    const Request request = {header, data};
+    const RequestUnderTest request = {header, data};
     putInBuffer(request);
     sent.push_back(request);
   }
@@ -64,11 +64,12 @@
   testee.onData(buffer_);
 
   // then
-  const std::vector<MessageSharedPtr>& received = request_callback->getCaptured();
+  const std::vector<AbstractRequestSharedPtr>& received = request_callback->getCaptured();
   ASSERT_EQ(received.size(), sent.size());
 
   for (size_t i = 0; i < received.size(); ++i) {
-    const std::shared_ptr<Request> request = std::dynamic_pointer_cast<Request>(received[i]);
+    const std::shared_ptr<RequestUnderTest> request =
+        std::dynamic_pointer_cast<RequestUnderTest>(received[i]);
     ASSERT_NE(request, nullptr);
     ASSERT_EQ(*request, sent[i]);
   }
@@ -77,8 +78,8 @@ TEST_F(RequestCodecRequestTest, shouldHandle{{
request_type.name }}Messages) { template void RequestCodecRequestTest::putInBuffer(const T arg) { - MessageEncoderImpl serializer{buffer_}; - serializer.encode(arg); + RequestEncoder encoder{buffer_}; + encoder.encode(arg); } } // namespace RequestCodecRequestTest diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 index 33f67b3f85f83..db14f9e2a55cf 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/request_parser.j2 @@ -3,16 +3,18 @@ Rendered templates for each request in Kafka protocol will be put into 'requests.h' file. This template handles binding the top-level structure deserializer - (e.g. ProduceRequestV0Deserializer) with RequestParser. These parsers are then used by + (e.g. ProduceRequestV0Deserializer) with RequestDataParser. These parsers are then used by RequestParserResolver instance depending on received Kafka api key & api version (see 'kafka_request_resolver_cc.j2'). #} {% for version in complex_type.versions %}class {{ complex_type.name }}V{{ version }}Parser: - public RequestParser<{{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> + public RequestDataParser< + {{ complex_type.name }}, {{ complex_type.name }}V{{ version }}Deserializer> { public: - {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : RequestParser{ctx} {}; + {{ complex_type.name }}V{{ version }}Parser(RequestContextSharedPtr ctx) : + RequestDataParser{ctx} {}; }; {% endfor %} \ No newline at end of file diff --git a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 index 4ed97ab2a0b37..d7ec7ae98ca4f 100644 --- a/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 +++ b/source/extensions/filters/network/kafka/protocol_code_generator/requests_test_cc.j2 @@ -26,7 +26,8 @@ public: class MockMessageListener : public RequestCallback { public: - MOCK_METHOD1(onMessage, void(MessageSharedPtr)); + MOCK_METHOD1(onMessage, void(AbstractRequestSharedPtr)); + MOCK_METHOD1(onFailedParse, void(RequestParseFailureSharedPtr)); }; /** @@ -35,13 +36,13 @@ public: * This method gets executed for every request * version pair. 
*/ template std::shared_ptr RequestTest::serializeAndDeserialize(T request) { - MessageEncoderImpl serializer{buffer_}; - serializer.encode(request); + RequestEncoder encoder{buffer_}; + encoder.encode(request); std::shared_ptr mock_listener = std::make_shared(); RequestDecoder testee{RequestParserResolver::getDefaultInstance(), {mock_listener}}; - MessageSharedPtr receivedMessage; + AbstractRequestSharedPtr receivedMessage; EXPECT_CALL(*mock_listener, onMessage(testing::_)) .WillOnce(testing::SaveArg<0>(&receivedMessage)); @@ -59,7 +60,7 @@ template std::shared_ptr RequestTest::serializeAndDeserialize(T TEST_F(RequestTest, shouldParse{{ request_type.name }}V{{ field_list.version }}) { // given {{ request_type.name }} data = { {{ field_list.example_value() }} }; - ConcreteRequest<{{ request_type.name }}> request = { { + Request<{{ request_type.name }}> request = { { {{ request_type.get_extra('api_key') }}, {{ field_list.version }}, 0, absl::nullopt }, data }; // when diff --git a/source/extensions/filters/network/kafka/request_codec.cc b/source/extensions/filters/network/kafka/request_codec.cc index 49cf6c1afabb9..7b46c77a0f76a 100644 --- a/source/extensions/filters/network/kafka/request_codec.cc +++ b/source/extensions/filters/network/kafka/request_codec.cc @@ -11,7 +11,7 @@ namespace NetworkFilters { namespace Kafka { class RequestStartParserFactory : public InitialParserFactory { - ParserSharedPtr create(const RequestParserResolver& parser_resolver) const override { + RequestParserSharedPtr create(const RequestParserResolver& parser_resolver) const override { return std::make_shared(parser_resolver); } }; @@ -48,15 +48,21 @@ void RequestDecoder::doParse(const Buffer::RawSlice& slice) { while (!data.empty()) { // Feed the data to the parser. - ParseResponse result = current_parser_->parse(data); + RequestParseResponse result = current_parser_->parse(data); // This loop guarantees that parsers consuming 0 bytes also get processed in this invocation. while (result.hasData()) { if (!result.next_parser_) { // Next parser is not present, so we have finished parsing a message. - MessageSharedPtr message = result.message_; - for (auto& callback : callbacks_) { - callback->onMessage(result.message_); + // Depending on whether the parse was successful, invoke the correct callback. + if (result.message_) { + for (auto& callback : callbacks_) { + callback->onMessage(result.message_); + } + } else { + for (auto& callback : callbacks_) { + callback->onFailedParse(result.failure_data_); + } } // As we finished parsing this request, re-initialize the parser. @@ -73,7 +79,7 @@ void RequestDecoder::doParse(const Buffer::RawSlice& slice) { } } -void MessageEncoderImpl::encode(const Message& message) { +void RequestEncoder::encode(const AbstractRequest& message) { Buffer::OwnedImpl data_buffer; // TODO(adamkotwasinski) Precompute the size instead of using temporary buffer. // When we have the 'computeSize' method, then we can push encoding request's size into diff --git a/source/extensions/filters/network/kafka/request_codec.h b/source/extensions/filters/network/kafka/request_codec.h index 7ddbb2f1417fd..c8a6b69f87973 100644 --- a/source/extensions/filters/network/kafka/request_codec.h +++ b/source/extensions/filters/network/kafka/request_codec.h @@ -22,55 +22,60 @@ class RequestCallback { /** * Callback method invoked when request is successfully decoded. - * @param request request that has been decoded + * @param request request that has been decoded. 
*/ - virtual void onMessage(MessageSharedPtr request) PURE; + virtual void onMessage(AbstractRequestSharedPtr request) PURE; + + /** + * Callback method invoked when request could not be decoded. + * Invoked after all request's bytes have been consumed. + */ + virtual void onFailedParse(RequestParseFailureSharedPtr failure_data) PURE; }; typedef std::shared_ptr RequestCallbackSharedPtr; /** - * Provides initial parser for messages - * (class extracted to allow injecting test factories) + * Provides initial parser for messages (class extracted to allow injecting test factories). */ class InitialParserFactory { public: virtual ~InitialParserFactory() = default; /** - * Creates default instance that returns RequestStartParser instances + * Creates default instance that returns RequestStartParser instances. */ static const InitialParserFactory& getDefaultInstance(); /** - * Creates parser with given context + * Creates parser with given context. */ - virtual ParserSharedPtr create(const RequestParserResolver& parser_resolver) const PURE; + virtual RequestParserSharedPtr create(const RequestParserResolver& parser_resolver) const PURE; }; /** - * Decoder that decodes Kafka requests - * When a request is decoded, the callbacks are notified, in order + * Decoder that decodes Kafka requests. + * When a request is decoded, the callbacks are notified, in order. * - * This decoder uses chain of parsers to parse fragments of a request - * Each parser along the line returns the fully parsed message or the next parser - * Stores parse state (as large message's payload can be provided through multiple `onData` calls) + * This decoder uses chain of parsers to parse fragments of a request. + * Each parser along the line returns the fully parsed message or the next parser. + * Stores parse state (as large message's payload can be provided through multiple `onData` calls). */ class RequestDecoder : public MessageDecoder { public: /** * Creates a decoder that can decode requests specified by RequestParserResolver, notifying - * callbacks on successful decoding - * @param parserResolver supported parser resolver - * @param callbacks callbacks to be invoked (in order) + * callbacks on successful decoding. + * @param parserResolver supported parser resolver. + * @param callbacks callbacks to be invoked (in order). */ RequestDecoder(const RequestParserResolver& parserResolver, const std::vector callbacks) : RequestDecoder(InitialParserFactory::getDefaultInstance(), parserResolver, callbacks){}; /** - * Visible for testing - * Allows injecting initial parser factory + * Visible for testing. + * Allows injecting initial parser factory. */ RequestDecoder(const InitialParserFactory& factory, const RequestParserResolver& parserResolver, const std::vector callbacks) @@ -78,10 +83,10 @@ class RequestDecoder : public MessageDecoder { current_parser_{factory_.create(parser_resolver_)} {}; /** - * Consumes all data present in a buffer - * If a request can be successfully parsed, then callbacks get notified with parsed request - * Updates decoder state - * impl note: similar to redis codec, which also keeps state + * Consumes all data present in a buffer. + * If a request can be successfully parsed, then callbacks get notified with parsed request. + * Updates decoder state. + * Impl note: similar to redis codec, which also keeps state. 
*/ void onData(Buffer::Instance& data) override; @@ -94,23 +99,23 @@ class RequestDecoder : public MessageDecoder { const std::vector callbacks_; - ParserSharedPtr current_parser_; + RequestParserSharedPtr current_parser_; }; /** - * Encodes provided messages into underlying buffer + * Encodes requests into underlying buffer. */ -class MessageEncoderImpl : public MessageEncoder { +class RequestEncoder : public MessageEncoder { public: /** - * Wraps buffer with encoder + * Wraps buffer with encoder. */ - MessageEncoderImpl(Buffer::Instance& output) : output_(output) {} + RequestEncoder(Buffer::Instance& output) : output_(output) {} /** - * Encodes request into wrapped buffer + * Encodes request into wrapped buffer. */ - void encode(const Message& message) override; + void encode(const AbstractRequest& message) override; private: Buffer::Instance& output_; diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index 661c84d03da58..38542c44d8761 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -27,7 +27,7 @@ namespace Kafka { * When ready(), it is safe to call get() to transform the internally stored bytes into result. * Further feed()-ing should have no effect on a buffer (should return 0 and not move * provided pointer). - * @param T type of deserialized data + * @param T type of deserialized data. */ template class Deserializer { public: @@ -37,8 +37,8 @@ template class Deserializer { * Submit data to be processed, will consume as much data as it is necessary. * If any bytes are consumed, then the provided string view is updated by stepping over consumed * bytes. Invoking this method when deserializer is ready has no effect (consumes 0 bytes). - * @param data bytes to be processed, will be updated if any have been consumed - * @return number of bytes consumed (equal to change in 'data') + * @param data bytes to be processed, will be updated if any have been consumed. + * @return number of bytes consumed (equal to change in 'data'). */ virtual size_t feed(absl::string_view& data) PURE; @@ -145,7 +145,7 @@ class Int64Deserializer : public IntDeserializer { }; /** - * Deserializer for boolean values + * Deserializer for boolean values. * Uses a single int8 deserializer, and checks whether the results equals 0. * When reading a boolean value, any non-zero value is considered true. * Impl note: could have been a subclass of IntDeserializer with a different get function, @@ -169,7 +169,7 @@ class BooleanDeserializer : public Deserializer { * Deserializer of string value. * First reads length (INT16) and then allocates the buffer of given length. * - * From documentation: + * From Kafka documentation: * First the length N is given as an INT16. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * Length must not be negative. @@ -230,7 +230,7 @@ class StringDeserializer : public Deserializer { * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning * null value). * - * From documentation: + * From Kafka documentation: * For non-null strings, first the length N is given as an INT16. * Then N bytes follow which are the UTF-8 encoding of the character sequence. * A null value is encoded with length of -1 and there are no following bytes. @@ -304,7 +304,7 @@ class NullableStringDeserializer : public Deserializer { * Deserializer of bytes value. 
* First reads length (INT32) and then allocates the buffer of given length. * - * From documentation: + * From Kafka documentation: * First the length N is given as an INT32. Then N bytes follow. */ class BytesDeserializer : public Deserializer { @@ -362,7 +362,7 @@ class BytesDeserializer : public Deserializer { * If length was -1, buffer allocation is omitted and deserializer is immediately ready (returning * null value). * - * From documentation: + * From Kafka documentation: * For non-null values, first the length N is given as an INT32. Then N bytes follow. * A null value is encoded with length of -1 and there are no following bytes. */ @@ -439,10 +439,10 @@ class NullableBytesDeserializer : public Deserializer { * First reads the length of the array, then initializes N underlying deserializers of type * DeserializerType. After the last of N deserializers is ready, the results of each of them are * gathered and put in a vector. - * @param ResponseType result type returned by deserializer of type DeserializerType - * @param DeserializerType underlying deserializer type + * @param ResponseType result type returned by deserializer of type DeserializerType. + * @param DeserializerType underlying deserializer type. * - * From documentation: + * From Kafka documentation: * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T * follow. A null array is represented with a length of -1. @@ -516,10 +516,10 @@ class ArrayDeserializer : public Deserializer> { * First reads the length of the array, then initializes N underlying deserializers of type * DeserializerType. After the last of N deserializers is ready, the results of each of them are * gathered and put in a vector. - * @param ResponseType result type returned by deserializer of type DeserializerType - * @param DeserializerType underlying deserializer type + * @param ResponseType result type returned by deserializer of type DeserializerType. + * @param DeserializerType underlying deserializer type. * - * From documentation: + * From Kafka documentation: * Represents a sequence of objects of a given type T. Type T can be either a primitive type (e.g. * STRING) or a structure. First, the length N is given as an int32_t. Then N instances of type T * follow. A null array is represented with a length of -1. diff --git a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 index 271d0084332f5..5b88fa9538bed 100644 --- a/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 +++ b/source/extensions/filters/network/kafka/serialization_code_generator/serialization_composite_h.j2 @@ -40,7 +40,7 @@ namespace Kafka { * Composite deserializer that uses 0 deserializer(s) (corner case). * Does not consume any bytes, and is always ready to return the result. * Creates a result value using the no-arg ResponseType constructor. - * @param ResponseType type of deserialized data + * @param ResponseType type of deserialized data. */ template class CompositeDeserializerWith0Delegates : public Deserializer { @@ -58,10 +58,10 @@ public: * consume data, so it's safe). * The composite deserializer is ready when the last deserializer is ready (what means that all * deserializers before it are ready too). 
- * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... } + * Constructs the result of type ResponseType using { delegate1_.get(), delegate2_.get() ... }. * - * @param ResponseType type of deserialized data{% for field in range(1, field_count + 1) %} - * @param DeserializerType{{ field }} deserializer {{ field }} + * @param ResponseType type of deserialized data{% for field in range(1, field_count + 1) %}. + * @param DeserializerType{{ field }} deserializer {{ field }}. {% endfor %} */ template < typename ResponseType{% for field in range(1, field_count + 1) %}, diff --git a/test/extensions/filters/network/kafka/BUILD b/test/extensions/filters/network/kafka/BUILD index c7cd341f5192b..9bb39da434071 100644 --- a/test/extensions/filters/network/kafka/BUILD +++ b/test/extensions/filters/network/kafka/BUILD @@ -64,7 +64,7 @@ envoy_extension_cc_test( extension_name = "envoy.filters.network.kafka", deps = [ ":serialization_utilities_lib", - "//source/extensions/filters/network/kafka:kafka_request_lib", + "//source/extensions/filters/network/kafka:kafka_request_parser_lib", "//test/mocks/server:server_mocks", ], ) diff --git a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc index c23c392840d57..e7163e0fab865 100644 --- a/test/extensions/filters/network/kafka/kafka_request_parser_test.cc +++ b/test/extensions/filters/network/kafka/kafka_request_parser_test.cc @@ -42,7 +42,8 @@ class KafkaRequestParserTest : public testing::Test { class MockRequestParserResolver : public RequestParserResolver { public: MockRequestParserResolver(){}; - MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); + MOCK_CONST_METHOD3(createParser, + RequestParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); }; TEST_F(KafkaRequestParserTest, RequestStartParserTestShouldReturnRequestHeaderParser) { @@ -57,19 +58,20 @@ TEST_F(KafkaRequestParserTest, RequestStartParserTestShouldReturnRequestHeaderPa absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(data); + const RequestParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_len); assertStringViewIncrement(data, orig_data, sizeof(int32_t)); } -class MockParser : public Parser { +class MockParser : public RequestParser { public: - ParseResponse parse(absl::string_view&) override { + RequestParseResponse parse(absl::string_view&) override { throw new EnvoyException("should not be invoked"); } }; @@ -77,7 +79,7 @@ class MockParser : public Parser { TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldExtractHeaderAndResolveNextParser) { // given const MockRequestParserResolver parser_resolver; - const ParserSharedPtr parser{new MockParser{}}; + const RequestParserSharedPtr parser{new MockParser{}}; EXPECT_CALL(parser_resolver, createParser(_, _, _)).WillOnce(Return(parser)); const int32_t request_len = 1000; @@ -99,12 +101,13 @@ TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldExtractHeaderAndResolveN absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(data); + const RequestParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); 
ASSERT_EQ(result.next_parser_, parser); ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); const RequestHeader expected_header{api_key, api_version, correlation_id, client_id}; ASSERT_EQ(testee.contextForTest()->request_header_, expected_header); @@ -143,12 +146,13 @@ TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldHandleExceptionsDuringFe absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(data); + const RequestParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_size - FAILED_DESERIALIZER_STEP); @@ -156,7 +160,7 @@ TEST_F(KafkaRequestParserTest, RequestHeaderParserShouldHandleExceptionsDuringFe assertStringViewIncrement(data, orig_data, FAILED_DESERIALIZER_STEP); } -TEST_F(KafkaRequestParserTest, RequestParserShouldHandleDeserializerExceptionsDuringFeeding) { +TEST_F(KafkaRequestParserTest, RequestDataParserShouldHandleDeserializerExceptionsDuringFeeding) { // given // This deserializer throws during feeding. @@ -173,7 +177,7 @@ TEST_F(KafkaRequestParserTest, RequestParserShouldHandleDeserializerExceptionsDu }; RequestContextSharedPtr request_context{new RequestContext{1024, {}}}; - RequestParser testee{request_context}; + RequestDataParser testee{request_context}; absl::string_view data = putGarbageIntoBuffer(); @@ -202,23 +206,25 @@ class SomeBytesDeserializer : public Deserializer { int32_t get() const override { return 0; }; }; -TEST_F(KafkaRequestParserTest, RequestParserShouldHandleDeserializerReturningReadyButLeavingData) { +TEST_F(KafkaRequestParserTest, + RequestDataParserShouldHandleDeserializerReturningReadyButLeavingData) { // given const int32_t request_size = 1024; // There are still 1024 bytes to read to complete the request. 
RequestContextSharedPtr request_context{new RequestContext{request_size, {}}}; - RequestParser testee{request_context}; + RequestDataParser testee{request_context}; const absl::string_view orig_data = putGarbageIntoBuffer(); absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(data); + const RequestParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_NE(std::dynamic_pointer_cast(result.next_parser_), nullptr); ASSERT_EQ(result.message_, nullptr); + ASSERT_EQ(result.failure_data_, nullptr); ASSERT_EQ(testee.contextForTest()->remaining_request_size_, request_size - FAILED_DESERIALIZER_STEP); @@ -237,12 +243,13 @@ TEST_F(KafkaRequestParserTest, SentinelParserShouldConsumeDataUntilEndOfRequest) absl::string_view data = orig_data; // when - const ParseResponse result = testee.parse(data); + const RequestParseResponse result = testee.parse(data); // then ASSERT_EQ(result.hasData(), true); ASSERT_EQ(result.next_parser_, nullptr); - ASSERT_NE(std::dynamic_pointer_cast(result.message_), nullptr); + ASSERT_EQ(result.message_, nullptr); + ASSERT_NE(std::dynamic_pointer_cast(result.failure_data_), nullptr); ASSERT_EQ(testee.contextForTest()->remaining_request_size_, 0); diff --git a/test/extensions/filters/network/kafka/request_codec_integration_test.cc b/test/extensions/filters/network/kafka/request_codec_integration_test.cc index b82042ebf8635..6d087f310466a 100644 --- a/test/extensions/filters/network/kafka/request_codec_integration_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_integration_test.cc @@ -29,7 +29,7 @@ TEST_F(RequestCodecIntegrationTest, shouldProduceAbortedMessageOnUnknownData) { const int16_t api_key = static_cast(base_api_key + i); const RequestHeader header = {api_key, 0, 0, "client-id"}; const std::vector data = std::vector(1024); - putInBuffer(ConcreteRequest>{header, data}); + putInBuffer(Request>{header, data}); sent_headers.push_back(header); } @@ -45,21 +45,24 @@ TEST_F(RequestCodecIntegrationTest, shouldProduceAbortedMessageOnUnknownData) { testee.onData(buffer_); // then - const std::vector& received = request_callback->getCaptured(); - ASSERT_EQ(received.size(), sent_headers.size()); + ASSERT_EQ(request_callback->getCaptured().size(), 0); - for (size_t i = 0; i < received.size(); ++i) { - const std::shared_ptr request = - std::dynamic_pointer_cast(received[i]); - ASSERT_NE(request, nullptr); - ASSERT_EQ(request->request_header_, sent_headers[i]); + const std::vector& parse_failures = + request_callback->getParseFailures(); + ASSERT_EQ(parse_failures.size(), sent_headers.size()); + + for (size_t i = 0; i < parse_failures.size(); ++i) { + const std::shared_ptr failure_data = + std::dynamic_pointer_cast(parse_failures[i]); + ASSERT_NE(failure_data, nullptr); + ASSERT_EQ(failure_data->request_header_, sent_headers[i]); } } // Helper function. 
template void RequestCodecIntegrationTest::putInBuffer(T arg) { - MessageEncoderImpl serializer{buffer_}; - serializer.encode(arg); + RequestEncoder encoder{buffer_}; + encoder.encode(arg); } } // namespace RequestCodecIntegrationTest diff --git a/test/extensions/filters/network/kafka/request_codec_unit_test.cc b/test/extensions/filters/network/kafka/request_codec_unit_test.cc index 39f9256ae660b..3685f019759e3 100644 --- a/test/extensions/filters/network/kafka/request_codec_unit_test.cc +++ b/test/extensions/filters/network/kafka/request_codec_unit_test.cc @@ -20,12 +20,12 @@ namespace RequestCodecUnitTest { class MockParserFactory : public InitialParserFactory { public: - MOCK_CONST_METHOD1(create, ParserSharedPtr(const RequestParserResolver&)); + MOCK_CONST_METHOD1(create, RequestParserSharedPtr(const RequestParserResolver&)); }; -class MockParser : public Parser { +class MockParser : public RequestParser { public: - MOCK_METHOD1(parse, ParseResponse(absl::string_view&)); + MOCK_METHOD1(parse, RequestParseResponse(absl::string_view&)); }; typedef std::shared_ptr MockParserSharedPtr; @@ -33,12 +33,14 @@ typedef std::shared_ptr MockParserSharedPtr; class MockRequestParserResolver : public RequestParserResolver { public: MockRequestParserResolver() : RequestParserResolver({}){}; - MOCK_CONST_METHOD3(createParser, ParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); + MOCK_CONST_METHOD3(createParser, + RequestParserSharedPtr(int16_t, int16_t, RequestContextSharedPtr)); }; class MockRequestCallback : public RequestCallback { public: - MOCK_METHOD1(onMessage, void(MessageSharedPtr)); + MOCK_METHOD1(onMessage, void(AbstractRequestSharedPtr)); + MOCK_METHOD1(onFailedParse, void(RequestParseFailureSharedPtr)); }; typedef std::shared_ptr MockRequestCallbackSharedPtr; @@ -54,14 +56,14 @@ class RequestCodecUnitTest : public testing::Test { MockRequestCallbackSharedPtr request_callback_{std::make_shared()}; }; -ParseResponse consumeOneByte(absl::string_view& data) { +RequestParseResponse consumeOneByte(absl::string_view& data) { data = {data.data() + 1, data.size() - 1}; - return ParseResponse::stillWaiting(); + return RequestParseResponse::stillWaiting(); } TEST_F(RequestCodecUnitTest, shouldDoNothingIfParserNeverReturnsMessage) { // given - putInBuffer(ConcreteRequest{{}, 0}); + putInBuffer(Request{{}, 0}); MockParserSharedPtr parser = std::make_shared(); EXPECT_CALL(*parser, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); @@ -79,13 +81,13 @@ TEST_F(RequestCodecUnitTest, shouldDoNothingIfParserNeverReturnsMessage) { TEST_F(RequestCodecUnitTest, shouldUseNewParserAsResponse) { // given - putInBuffer(ConcreteRequest{{}, 0}); + putInBuffer(Request{{}, 0}); MockParserSharedPtr parser1 = std::make_shared(); MockParserSharedPtr parser2 = std::make_shared(); MockParserSharedPtr parser3 = std::make_shared(); - EXPECT_CALL(*parser1, parse(_)).WillOnce(Return(ParseResponse::nextParser(parser2))); - EXPECT_CALL(*parser2, parse(_)).WillOnce(Return(ParseResponse::nextParser(parser3))); + EXPECT_CALL(*parser1, parse(_)).WillOnce(Return(RequestParseResponse::nextParser(parser2))); + EXPECT_CALL(*parser2, parse(_)).WillOnce(Return(RequestParseResponse::nextParser(parser3))); EXPECT_CALL(*parser3, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); EXPECT_CALL(initial_parser_factory_, create(_)).WillOnce(Return(parser1)); @@ -99,13 +101,15 @@ TEST_F(RequestCodecUnitTest, shouldUseNewParserAsResponse) { // There were no interactions with `request_callback`. 
} -TEST_F(RequestCodecUnitTest, shouldReturnParsedMessageAndReinitialize) { +TEST_F(RequestCodecUnitTest, shouldPassParsedMessageToCallbackAndReinitialize) { // given - putInBuffer(ConcreteRequest{{}, 0}); + putInBuffer(Request{{}, 0}); MockParserSharedPtr parser1 = std::make_shared(); - MessageSharedPtr message = std::make_shared(RequestHeader{}); - EXPECT_CALL(*parser1, parse(_)).WillOnce(Return(ParseResponse::parsedMessage(message))); + RequestParseFailureSharedPtr failure_data = + std::make_shared(RequestHeader()); + EXPECT_CALL(*parser1, parse(_)) + .WillOnce(Return(RequestParseResponse::parseFailure(failure_data))); MockParserSharedPtr parser2 = std::make_shared(); EXPECT_CALL(*parser2, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); @@ -114,7 +118,7 @@ TEST_F(RequestCodecUnitTest, shouldReturnParsedMessageAndReinitialize) { .WillOnce(Return(parser1)) .WillOnce(Return(parser2)); - EXPECT_CALL(*request_callback_, onMessage(message)); + EXPECT_CALL(*request_callback_, onFailedParse(failure_data)); RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; @@ -125,29 +129,64 @@ TEST_F(RequestCodecUnitTest, shouldReturnParsedMessageAndReinitialize) { // There was only one message sent to `request_callback`. } +TEST_F(RequestCodecUnitTest, shouldPassParseFailureDataToCallbackAndReinitialize) { + // given + putInBuffer(Request{{}, 0}); + + MockParserSharedPtr parser1 = std::make_shared(); + RequestParseFailureSharedPtr failure_data = + std::make_shared(RequestHeader()); + EXPECT_CALL(*parser1, parse(_)) + .WillOnce(Return(RequestParseResponse::parseFailure(failure_data))); + + MockParserSharedPtr parser2 = std::make_shared(); + EXPECT_CALL(*parser2, parse(_)).Times(AnyNumber()).WillRepeatedly(Invoke(consumeOneByte)); + + EXPECT_CALL(initial_parser_factory_, create(_)) + .WillOnce(Return(parser1)) + .WillOnce(Return(parser2)); + + EXPECT_CALL(*request_callback_, onFailedParse(failure_data)); + + RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; + + // when + testee.onData(buffer_); + + // then + // `request_callback` had `onFailedParse` invoked once with matching argument. 
+} + TEST_F(RequestCodecUnitTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) { // given - putInBuffer(ConcreteRequest{{}, 0}); + putInBuffer(Request{{}, 0}); MockParserSharedPtr parser1 = std::make_shared(); MockParserSharedPtr parser2 = std::make_shared(); MockParserSharedPtr parser3 = std::make_shared(); - auto consume_and_return = [this, &parser2](absl::string_view& data) -> ParseResponse { + // parser1 consumes buffer_.length() bytes (== everything) and returns parser2 + auto consume_and_return = [this, &parser2](absl::string_view& data) -> RequestParseResponse { data = {data.data() + buffer_.length(), data.size() - buffer_.length()}; - return ParseResponse::nextParser(parser2); + return RequestParseResponse::nextParser(parser2); }; EXPECT_CALL(*parser1, parse(_)).WillOnce(Invoke(consume_and_return)); - MessageSharedPtr message = std::make_shared(RequestHeader{}); - EXPECT_CALL(*parser2, parse(_)).WillOnce(Return(ParseResponse::parsedMessage(message))); + + // parser2 just returns parse result + RequestParseFailureSharedPtr failure_data = + std::make_shared(RequestHeader{}); + EXPECT_CALL(*parser2, parse(_)) + .WillOnce(Return(RequestParseResponse::parseFailure(failure_data))); + + // parser3 just consumes everything EXPECT_CALL(*parser3, parse(ResultOf([](absl::string_view arg) { return arg.size(); }, Eq(0)))) - .WillOnce(Return(ParseResponse::stillWaiting())); + .WillOnce(Return(RequestParseResponse::stillWaiting())); EXPECT_CALL(initial_parser_factory_, create(_)) .WillOnce(Return(parser1)) .WillOnce(Return(parser3)); - EXPECT_CALL(*request_callback_, onMessage(message)); + EXPECT_CALL(*request_callback_, onFailedParse(failure_data)); RequestDecoder testee{initial_parser_factory_, parser_resolver_, {request_callback_}}; @@ -155,14 +194,14 @@ TEST_F(RequestCodecUnitTest, shouldInvokeParsersEvenIfTheyDoNotConsumeZeroBytes) testee.onData(buffer_); // then - // There was only one message sent to `request_callback`. + // `request_callback` was invoked only once. // After that, `parser3` was created and passed remaining data (that should have been empty). } // Helper function. 
template <typename T> void RequestCodecUnitTest::putInBuffer(T arg) {
-  MessageEncoderImpl serializer{buffer_};
-  serializer.encode(arg);
+  RequestEncoder encoder{buffer_};
+  encoder.encode(arg);
 }
 
 } // namespace RequestCodecUnitTest
diff --git a/test/extensions/filters/network/kafka/serialization_utilities.cc b/test/extensions/filters/network/kafka/serialization_utilities.cc
index 63eca308c269c..95129da2e4003 100644
--- a/test/extensions/filters/network/kafka/serialization_utilities.cc
+++ b/test/extensions/filters/network/kafka/serialization_utilities.cc
@@ -19,12 +19,23 @@ const char* getRawData(const Buffer::OwnedImpl& buffer) {
   return reinterpret_cast<const char*>((slices[0]).mem_);
 }
 
-void CapturingRequestCallback::onMessage(MessageSharedPtr message) { captured_.push_back(message); }
+void CapturingRequestCallback::onMessage(AbstractRequestSharedPtr message) {
+  captured_.push_back(message);
+}
+
+void CapturingRequestCallback::onFailedParse(RequestParseFailureSharedPtr failure_data) {
+  parse_failures_.push_back(failure_data);
+}
 
-const std::vector<MessageSharedPtr>& CapturingRequestCallback::getCaptured() const {
+const std::vector<AbstractRequestSharedPtr>& CapturingRequestCallback::getCaptured() const {
   return captured_;
 }
 
+const std::vector<RequestParseFailureSharedPtr>&
+CapturingRequestCallback::getParseFailures() const {
+  return parse_failures_;
+}
+
 } // namespace Kafka
 } // namespace NetworkFilters
 } // namespace Extensions
diff --git a/test/extensions/filters/network/kafka/serialization_utilities.h b/test/extensions/filters/network/kafka/serialization_utilities.h
index 37a1540c040cd..d15071ffa23b6 100644
--- a/test/extensions/filters/network/kafka/serialization_utilities.h
+++ b/test/extensions/filters/network/kafka/serialization_utilities.h
@@ -121,15 +121,20 @@ class CapturingRequestCallback : public RequestCallback {
   /**
    * Stores the message.
    */
-  virtual void onMessage(MessageSharedPtr request) override;
+  virtual void onMessage(AbstractRequestSharedPtr request) override;
 
   /**
    * Returns the stored messages.
    */
-  const std::vector<MessageSharedPtr>& getCaptured() const;
+  const std::vector<AbstractRequestSharedPtr>& getCaptured() const;
+
+  virtual void onFailedParse(RequestParseFailureSharedPtr failure_data) override;
+
+  const std::vector<RequestParseFailureSharedPtr>& getParseFailures() const;
 
 private:
-  std::vector<MessageSharedPtr> captured_;
+  std::vector<AbstractRequestSharedPtr> captured_;
+  std::vector<RequestParseFailureSharedPtr> parse_failures_;
 };
 
 typedef std::shared_ptr<CapturingRequestCallback> CapturingRequestCallbackSharedPtr;
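
For readers following the series, the sketch below shows how the pieces introduced by patch 29 fit together from a consumer's point of view. It is a minimal, illustrative sketch and not part of the patches themselves: RequestDecoder, RequestCallback, RequestParserResolver, AbstractRequestSharedPtr and RequestParseFailureSharedPtr all come from the diffs above, while LoggingCallback and decodeSketch are hypothetical names invented here for illustration.

// Hypothetical consumer of the refactored request codec (illustration only, not in the patch).
#include "common/buffer/buffer_impl.h"
#include "common/common/logger.h"

#include "extensions/filters/network/kafka/request_codec.h"

namespace Envoy {
namespace Extensions {
namespace NetworkFilters {
namespace Kafka {

// After this refactor, a callback receives successful parses and parse failures separately.
class LoggingCallback : public RequestCallback {
public:
  void onMessage(AbstractRequestSharedPtr request) override {
    // Fully parsed request; its header is always populated.
    ENVOY_LOG_MISC(trace, "parsed request with api_key {}", request->request_header_.api_key_);
  }

  void onFailedParse(RequestParseFailureSharedPtr failure) override {
    // Unknown api_key/api_version or surplus bytes: SentinelParser consumed the payload,
    // so only the request header is available here.
    ENVOY_LOG_MISC(trace, "unparseable request with api_key {}",
                   failure->request_header_.api_key_);
  }
};

void decodeSketch() {
  const RequestCallbackSharedPtr callback = std::make_shared<LoggingCallback>();
  RequestDecoder decoder{RequestParserResolver::getDefaultInstance(), {callback}};

  Buffer::OwnedImpl data;
  // ... fill `data` with bytes received from a Kafka client ...
  decoder.onData(data); // Callbacks fire for every request completed by these bytes.
  // Subsequent onData() calls reuse the decoder's parser state, so a request split
  // across several network reads is still decoded correctly.
}

} // namespace Kafka
} // namespace NetworkFilters
} // namespace Extensions
} // namespace Envoy

The same templated Parser/ParseResponse machinery is what should later allow a response decoder to reuse this flow, which is the stated motivation in the commit message for removing message.h and making Encoder/Parser/ParseResponse templated.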