Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions source/extensions/filters/network/kafka/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,9 @@ envoy_cc_library(

envoy_cc_library(
name = "serialization_lib",
srcs = [
"serialization.cc",
],
hdrs = [
"external/serialization_composite.h",
"serialization.h",
Expand Down
110 changes: 110 additions & 0 deletions source/extensions/filters/network/kafka/serialization.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,110 @@
#include "extensions/filters/network/kafka/serialization.h"

namespace Envoy {
namespace Extensions {
namespace NetworkFilters {
namespace Kafka {

constexpr static int16_t NULL_STRING_LENGTH = -1;
constexpr static int32_t NULL_BYTES_LENGTH = -1;

/**
* Helper method for deserializers that get the length of data, and then copy the given bytes into a
* local buffer. Templated as there are length and byte type differences. Impl note: This method
* modifies (sets up) most of Deserializer's fields.
* @param data bytes to deserialize.
* @param length_deserializer payload length deserializer.
* @param length_consumed_marker marker telling whether length has been extracted from
* length_deserializer, and underlying buffer has been initialized.
* @param required remaining bytes to consume.
* @param data_buffer buffer with capacity for 'required' bytes.
* @param ready marker telling whether this deserialized has finished processing.
* @param null_value_length value marking null values.
* @param allow_null_value whether null value if allowed.
* @return number of bytes consumed.
*/
template <typename DeserializerType, typename LengthType, typename ByteType>
size_t feedBytesIntoBuffers(absl::string_view& data, DeserializerType& length_deserializer,
bool& length_consumed_marker, LengthType& required,
std::vector<ByteType>& data_buffer, bool& ready,
const LengthType null_value_length, const bool allow_null_value) {

const size_t length_consumed = length_deserializer.feed(data);
if (!length_deserializer.ready()) {
// Break early: we still need to fill in length buffer.
return length_consumed;
}

if (!length_consumed_marker) {
// Length buffer is ready, but we have not yet processed the result.
// We need to extract the real data length and initialize buffer for it.
required = length_deserializer.get();

if (required >= 0) {
data_buffer = std::vector<ByteType>(required);
}

if (required == null_value_length) {
if (allow_null_value) {
// We have received 'null' value in deserializer that allows it (e.g. NullableBytes), no
// more processing is necessary.
ready = true;
} else {
// Invalid payload: null length for non-null object.
throw EnvoyException(fmt::format("invalid length: {}", required));
}
}

if (required < null_value_length) {
throw EnvoyException(fmt::format("invalid length: {}", required));
}

length_consumed_marker = true;
}

if (ready) {
// Break early: we might not need to consume any bytes for nullable values OR in case of repeat
// invocation on already-ready buffer.
return length_consumed;
}

const size_t data_consumed = std::min<size_t>(required, data.size());
const size_t written = data_buffer.size() - required;
if (data_consumed > 0) {
memcpy(data_buffer.data() + written, data.data(), data_consumed);
required -= data_consumed;
data = {data.data() + data_consumed, data.size() - data_consumed};
}

// We have consumed all the bytes, mark the deserializer as ready.
if (required == 0) {
ready = true;
}

return length_consumed + data_consumed;
}

size_t StringDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int16Deserializer, int16_t, char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_STRING_LENGTH, false);
}

size_t NullableStringDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int16Deserializer, int16_t, char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_STRING_LENGTH, true);
}

size_t BytesDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int32Deserializer, int32_t, unsigned char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_BYTES_LENGTH, false);
}

size_t NullableBytesDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int32Deserializer, int32_t, unsigned char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_BYTES_LENGTH, true);
}

} // namespace Kafka
} // namespace NetworkFilters
} // namespace Extensions
} // namespace Envoy
158 changes: 5 additions & 153 deletions source/extensions/filters/network/kafka/serialization.h
Original file line number Diff line number Diff line change
Expand Up @@ -179,37 +179,7 @@ class StringDeserializer : public Deserializer<std::string> {
/**
* Can throw EnvoyException if given string length is not valid.
*/
size_t feed(absl::string_view& data) override {
const size_t length_consumed = length_buf_.feed(data);
if (!length_buf_.ready()) {
// Break early: we still need to fill in length buffer.
return length_consumed;
}

if (!length_consumed_) {
required_ = length_buf_.get();
if (required_ >= 0) {
data_buf_ = std::vector<char>(required_);
} else {
throw EnvoyException(fmt::format("invalid STRING length: {}", required_));
}
length_consumed_ = true;
}

const size_t data_consumed = std::min<size_t>(required_, data.size());
const size_t written = data_buf_.size() - required_;
if (data_consumed > 0) {
memcpy(data_buf_.data() + written, data.data(), data_consumed);
required_ -= data_consumed;
data = {data.data() + data_consumed, data.size() - data_consumed};
}

if (required_ == 0) {
ready_ = true;
}

return length_consumed + data_consumed;
}
size_t feed(absl::string_view& data) override;

bool ready() const override { return ready_; }

Expand Down Expand Up @@ -241,47 +211,7 @@ class NullableStringDeserializer : public Deserializer<NullableString> {
/**
* Can throw EnvoyException if given string length is not valid.
*/
size_t feed(absl::string_view& data) override {
const size_t length_consumed = length_buf_.feed(data);
if (!length_buf_.ready()) {
// Break early: we still need to fill in length buffer.
return length_consumed;
}

if (!length_consumed_) {
required_ = length_buf_.get();

if (required_ >= 0) {
data_buf_ = std::vector<char>(required_);
}
if (required_ == NULL_STRING_LENGTH) {
ready_ = true;
}
if (required_ < NULL_STRING_LENGTH) {
throw EnvoyException(fmt::format("invalid NULLABLE_STRING length: {}", required_));
}

length_consumed_ = true;
}

if (ready_) {
return length_consumed;
}

const size_t data_consumed = std::min<size_t>(required_, data.size());
const size_t written = data_buf_.size() - required_;
if (data_consumed > 0) {
memcpy(data_buf_.data() + written, data.data(), data_consumed);
required_ -= data_consumed;
data = {data.data() + data_consumed, data.size() - data_consumed};
}

if (required_ == 0) {
ready_ = true;
}

return length_consumed + data_consumed;
}
size_t feed(absl::string_view& data) override;

bool ready() const override { return ready_; }

Expand All @@ -291,8 +221,6 @@ class NullableStringDeserializer : public Deserializer<NullableString> {
}

private:
constexpr static int16_t NULL_STRING_LENGTH{-1};

Int16Deserializer length_buf_;
bool length_consumed_{false};

Expand All @@ -314,37 +242,7 @@ class BytesDeserializer : public Deserializer<Bytes> {
/**
* Can throw EnvoyException if given bytes length is not valid.
*/
size_t feed(absl::string_view& data) override {
const size_t length_consumed = length_buf_.feed(data);
if (!length_buf_.ready()) {
// Break early: we still need to fill in length buffer.
return length_consumed;
}

if (!length_consumed_) {
required_ = length_buf_.get();
if (required_ >= 0) {
data_buf_ = std::vector<unsigned char>(required_);
} else {
throw EnvoyException(fmt::format("invalid BYTES length: {}", required_));
}
length_consumed_ = true;
}

const size_t data_consumed = std::min<size_t>(required_, data.size());
const size_t written = data_buf_.size() - required_;
if (data_consumed > 0) {
memcpy(data_buf_.data() + written, data.data(), data_consumed);
required_ -= data_consumed;
data = {data.data() + data_consumed, data.size() - data_consumed};
}

if (required_ == 0) {
ready_ = true;
}

return length_consumed + data_consumed;
}
size_t feed(absl::string_view& data) override;

bool ready() const override { return ready_; }

Expand Down Expand Up @@ -374,61 +272,15 @@ class NullableBytesDeserializer : public Deserializer<NullableBytes> {
/**
* Can throw EnvoyException if given bytes length is not valid.
*/
size_t feed(absl::string_view& data) override {
const size_t length_consumed = length_buf_.feed(data);
if (!length_buf_.ready()) {
// Break early: we still need to fill in length buffer.
return length_consumed;
}

if (!length_consumed_) {
required_ = length_buf_.get();

if (required_ >= 0) {
data_buf_ = std::vector<unsigned char>(required_);
}
if (required_ == NULL_BYTES_LENGTH) {
ready_ = true;
}
if (required_ < NULL_BYTES_LENGTH) {
throw EnvoyException(fmt::format("invalid NULLABLE_BYTES length: {}", required_));
}

length_consumed_ = true;
}

if (ready_) {
return length_consumed;
}

const size_t data_consumed = std::min<size_t>(required_, data.size());
const size_t written = data_buf_.size() - required_;
if (data_consumed > 0) {
memcpy(data_buf_.data() + written, data.data(), data_consumed);
required_ -= data_consumed;
data = {data.data() + data_consumed, data.size() - data_consumed};
}

if (required_ == 0) {
ready_ = true;
}

return length_consumed + data_consumed;
}
size_t feed(absl::string_view& data) override;

bool ready() const override { return ready_; }

NullableBytes get() const override {
if (NULL_BYTES_LENGTH == required_) {
return absl::nullopt;
} else {
return {data_buf_};
}
return required_ >= 0 ? absl::make_optional(data_buf_) : absl::nullopt;
}

private:
constexpr static int32_t NULL_BYTES_LENGTH{-1};

Int32Deserializer length_buf_;
bool length_consumed_{false};
int32_t required_;
Expand Down