diff --git a/source/extensions/filters/network/kafka/BUILD b/source/extensions/filters/network/kafka/BUILD index 73bab5124c8e7..37bfa7a5daba7 100644 --- a/source/extensions/filters/network/kafka/BUILD +++ b/source/extensions/filters/network/kafka/BUILD @@ -91,6 +91,9 @@ envoy_cc_library( envoy_cc_library( name = "serialization_lib", + srcs = [ + "serialization.cc", + ], hdrs = [ "external/serialization_composite.h", "serialization.h", diff --git a/source/extensions/filters/network/kafka/serialization.cc b/source/extensions/filters/network/kafka/serialization.cc new file mode 100644 index 0000000000000..6d60e01a452fe --- /dev/null +++ b/source/extensions/filters/network/kafka/serialization.cc @@ -0,0 +1,110 @@ +#include "extensions/filters/network/kafka/serialization.h" + +namespace Envoy { +namespace Extensions { +namespace NetworkFilters { +namespace Kafka { + +constexpr static int16_t NULL_STRING_LENGTH = -1; +constexpr static int32_t NULL_BYTES_LENGTH = -1; + +/** + * Helper method for deserializers that get the length of data, and then copy the given bytes into a + * local buffer. Templated as there are length and byte type differences. Impl note: This method + * modifies (sets up) most of Deserializer's fields. + * @param data bytes to deserialize. + * @param length_deserializer payload length deserializer. + * @param length_consumed_marker marker telling whether length has been extracted from + * length_deserializer, and underlying buffer has been initialized. + * @param required remaining bytes to consume. + * @param data_buffer buffer with capacity for 'required' bytes. + * @param ready marker telling whether this deserialized has finished processing. + * @param null_value_length value marking null values. + * @param allow_null_value whether null value if allowed. + * @return number of bytes consumed. + */ +template +size_t feedBytesIntoBuffers(absl::string_view& data, DeserializerType& length_deserializer, + bool& length_consumed_marker, LengthType& required, + std::vector& data_buffer, bool& ready, + const LengthType null_value_length, const bool allow_null_value) { + + const size_t length_consumed = length_deserializer.feed(data); + if (!length_deserializer.ready()) { + // Break early: we still need to fill in length buffer. + return length_consumed; + } + + if (!length_consumed_marker) { + // Length buffer is ready, but we have not yet processed the result. + // We need to extract the real data length and initialize buffer for it. + required = length_deserializer.get(); + + if (required >= 0) { + data_buffer = std::vector(required); + } + + if (required == null_value_length) { + if (allow_null_value) { + // We have received 'null' value in deserializer that allows it (e.g. NullableBytes), no + // more processing is necessary. + ready = true; + } else { + // Invalid payload: null length for non-null object. + throw EnvoyException(fmt::format("invalid length: {}", required)); + } + } + + if (required < null_value_length) { + throw EnvoyException(fmt::format("invalid length: {}", required)); + } + + length_consumed_marker = true; + } + + if (ready) { + // Break early: we might not need to consume any bytes for nullable values OR in case of repeat + // invocation on already-ready buffer. + return length_consumed; + } + + const size_t data_consumed = std::min(required, data.size()); + const size_t written = data_buffer.size() - required; + if (data_consumed > 0) { + memcpy(data_buffer.data() + written, data.data(), data_consumed); + required -= data_consumed; + data = {data.data() + data_consumed, data.size() - data_consumed}; + } + + // We have consumed all the bytes, mark the deserializer as ready. + if (required == 0) { + ready = true; + } + + return length_consumed + data_consumed; +} + +size_t StringDeserializer::feed(absl::string_view& data) { + return feedBytesIntoBuffers( + data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_STRING_LENGTH, false); +} + +size_t NullableStringDeserializer::feed(absl::string_view& data) { + return feedBytesIntoBuffers( + data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_STRING_LENGTH, true); +} + +size_t BytesDeserializer::feed(absl::string_view& data) { + return feedBytesIntoBuffers( + data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_BYTES_LENGTH, false); +} + +size_t NullableBytesDeserializer::feed(absl::string_view& data) { + return feedBytesIntoBuffers( + data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_BYTES_LENGTH, true); +} + +} // namespace Kafka +} // namespace NetworkFilters +} // namespace Extensions +} // namespace Envoy diff --git a/source/extensions/filters/network/kafka/serialization.h b/source/extensions/filters/network/kafka/serialization.h index 7600a4c374d81..c8d3820ea2302 100644 --- a/source/extensions/filters/network/kafka/serialization.h +++ b/source/extensions/filters/network/kafka/serialization.h @@ -179,37 +179,7 @@ class StringDeserializer : public Deserializer { /** * Can throw EnvoyException if given string length is not valid. */ - size_t feed(absl::string_view& data) override { - const size_t length_consumed = length_buf_.feed(data); - if (!length_buf_.ready()) { - // Break early: we still need to fill in length buffer. - return length_consumed; - } - - if (!length_consumed_) { - required_ = length_buf_.get(); - if (required_ >= 0) { - data_buf_ = std::vector(required_); - } else { - throw EnvoyException(fmt::format("invalid STRING length: {}", required_)); - } - length_consumed_ = true; - } - - const size_t data_consumed = std::min(required_, data.size()); - const size_t written = data_buf_.size() - required_; - if (data_consumed > 0) { - memcpy(data_buf_.data() + written, data.data(), data_consumed); - required_ -= data_consumed; - data = {data.data() + data_consumed, data.size() - data_consumed}; - } - - if (required_ == 0) { - ready_ = true; - } - - return length_consumed + data_consumed; - } + size_t feed(absl::string_view& data) override; bool ready() const override { return ready_; } @@ -241,47 +211,7 @@ class NullableStringDeserializer : public Deserializer { /** * Can throw EnvoyException if given string length is not valid. */ - size_t feed(absl::string_view& data) override { - const size_t length_consumed = length_buf_.feed(data); - if (!length_buf_.ready()) { - // Break early: we still need to fill in length buffer. - return length_consumed; - } - - if (!length_consumed_) { - required_ = length_buf_.get(); - - if (required_ >= 0) { - data_buf_ = std::vector(required_); - } - if (required_ == NULL_STRING_LENGTH) { - ready_ = true; - } - if (required_ < NULL_STRING_LENGTH) { - throw EnvoyException(fmt::format("invalid NULLABLE_STRING length: {}", required_)); - } - - length_consumed_ = true; - } - - if (ready_) { - return length_consumed; - } - - const size_t data_consumed = std::min(required_, data.size()); - const size_t written = data_buf_.size() - required_; - if (data_consumed > 0) { - memcpy(data_buf_.data() + written, data.data(), data_consumed); - required_ -= data_consumed; - data = {data.data() + data_consumed, data.size() - data_consumed}; - } - - if (required_ == 0) { - ready_ = true; - } - - return length_consumed + data_consumed; - } + size_t feed(absl::string_view& data) override; bool ready() const override { return ready_; } @@ -291,8 +221,6 @@ class NullableStringDeserializer : public Deserializer { } private: - constexpr static int16_t NULL_STRING_LENGTH{-1}; - Int16Deserializer length_buf_; bool length_consumed_{false}; @@ -314,37 +242,7 @@ class BytesDeserializer : public Deserializer { /** * Can throw EnvoyException if given bytes length is not valid. */ - size_t feed(absl::string_view& data) override { - const size_t length_consumed = length_buf_.feed(data); - if (!length_buf_.ready()) { - // Break early: we still need to fill in length buffer. - return length_consumed; - } - - if (!length_consumed_) { - required_ = length_buf_.get(); - if (required_ >= 0) { - data_buf_ = std::vector(required_); - } else { - throw EnvoyException(fmt::format("invalid BYTES length: {}", required_)); - } - length_consumed_ = true; - } - - const size_t data_consumed = std::min(required_, data.size()); - const size_t written = data_buf_.size() - required_; - if (data_consumed > 0) { - memcpy(data_buf_.data() + written, data.data(), data_consumed); - required_ -= data_consumed; - data = {data.data() + data_consumed, data.size() - data_consumed}; - } - - if (required_ == 0) { - ready_ = true; - } - - return length_consumed + data_consumed; - } + size_t feed(absl::string_view& data) override; bool ready() const override { return ready_; } @@ -374,61 +272,15 @@ class NullableBytesDeserializer : public Deserializer { /** * Can throw EnvoyException if given bytes length is not valid. */ - size_t feed(absl::string_view& data) override { - const size_t length_consumed = length_buf_.feed(data); - if (!length_buf_.ready()) { - // Break early: we still need to fill in length buffer. - return length_consumed; - } - - if (!length_consumed_) { - required_ = length_buf_.get(); - - if (required_ >= 0) { - data_buf_ = std::vector(required_); - } - if (required_ == NULL_BYTES_LENGTH) { - ready_ = true; - } - if (required_ < NULL_BYTES_LENGTH) { - throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_)); - } - - length_consumed_ = true; - } - - if (ready_) { - return length_consumed; - } - - const size_t data_consumed = std::min(required_, data.size()); - const size_t written = data_buf_.size() - required_; - if (data_consumed > 0) { - memcpy(data_buf_.data() + written, data.data(), data_consumed); - required_ -= data_consumed; - data = {data.data() + data_consumed, data.size() - data_consumed}; - } - - if (required_ == 0) { - ready_ = true; - } - - return length_consumed + data_consumed; - } + size_t feed(absl::string_view& data) override; bool ready() const override { return ready_; } NullableBytes get() const override { - if (NULL_BYTES_LENGTH == required_) { - return absl::nullopt; - } else { - return {data_buf_}; - } + return required_ >= 0 ? absl::make_optional(data_buf_) : absl::nullopt; } private: - constexpr static int32_t NULL_BYTES_LENGTH{-1}; - Int32Deserializer length_buf_; bool length_consumed_{false}; int32_t required_;