Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 26 additions & 4 deletions source/extensions/filters/network/kafka/kafka_request.h
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,16 @@ class AbstractRequest {
AbstractRequest(const RequestHeader& request_header) : request_header_{request_header} {};

/**
* Encode the contents of this message into a given buffer.
* Computes the size of this request, if it were to be serialized.
* @return serialized size of request
*/
virtual uint32_t computeSize() const PURE;

/**
* Encode the contents of this request into a given buffer.
* @param dst buffer instance to keep serialized message
*/
virtual size_t encode(Buffer::Instance& dst) const PURE;
virtual uint32_t encode(Buffer::Instance& dst) const PURE;

/**
* Request's header.
Expand All @@ -82,12 +88,28 @@ template <typename Data> class Request : public AbstractRequest {
Request(const RequestHeader& request_header, const Data& data)
: AbstractRequest{request_header}, data_{data} {};

/**
 * Compute the size of request, which includes both the request header and its real data.
 * The header fields are sized in exactly the order encode() serializes them, so the
 * result always matches the number of bytes encode() would append.
 */
uint32_t computeSize() const override {
  const EncodingContext context{request_header_.api_version_};
  // Size of the request header: api_key, api_version, correlation_id, client_id.
  const uint32_t header_size = context.computeSize(request_header_.api_key_) +
                               context.computeSize(request_header_.api_version_) +
                               context.computeSize(request_header_.correlation_id_) +
                               context.computeSize(request_header_.client_id_);
  // Payload size depends on the api version carried by the context.
  return header_size + context.computeSize(data_);
}

/**
* Encodes given request into a buffer, with any extra configuration carried by the context.
*/
size_t encode(Buffer::Instance& dst) const override {
uint32_t encode(Buffer::Instance& dst) const override {
EncodingContext context{request_header_.api_version_};
size_t written{0};
uint32_t written{0};
// Encode request header.
written += context.encode(request_header_.api_key_, dst);
written += context.encode(request_header_.api_version_, dst);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,7 @@ RequestParseResponse RequestHeaderParser::parse(absl::string_view& data) {
}

RequestParseResponse SentinelParser::parse(absl::string_view& data) {
const size_t min = std::min<size_t>(context_->remaining_request_size_, data.size());
const uint32_t min = std::min<uint32_t>(context_->remaining_request_size_, data.size());
data = {data.data() + min, data.size() - min};
context_->remaining_request_size_ -= min;
if (0 == context_->remaining_request_size_) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,19 +22,36 @@ struct {{ complex_type.name }} {
// constructor used in versions: {{ constructor['versions'] }}
{{ constructor['full_declaration'] }}{% endfor %}

{# For every field that's used in version, just compute its size using an encoder. #}
{# The generated computeSize() checks the same version ranges as the generated encode(), #}
{# so the reported size matches exactly what encode() would write for that api version. #}
{% if complex_type.fields|length > 0 %}
uint32_t computeSize(const EncodingContext& encoder) const {
const int16_t api_version = encoder.apiVersion();
uint32_t written{0};{% for field in complex_type.fields %}
if (api_version >= {{ field.version_usage[0] }}
&& api_version < {{ field.version_usage[-1] + 1 }}) {
written += encoder.computeSize({{ field.name }}_);
}{% endfor %}
return written;
}
{% else %}
{# A structure without fields serializes to zero bytes in every version. #}
uint32_t computeSize(const EncodingContext&) const {
return 0;
}
{% endif %}

{# For every field that's used in version, just serialize it. #}
{# Returns the number of bytes appended to 'dst'; by construction this equals what #}
{# the generated computeSize() reports for the same api version. #}
{% if complex_type.fields|length > 0 %}
uint32_t encode(Buffer::Instance& dst, EncodingContext& encoder) const {
const int16_t api_version = encoder.apiVersion();
uint32_t written{0};{% for field in complex_type.fields %}
if (api_version >= {{ field.version_usage[0] }}
&& api_version < {{ field.version_usage[-1] + 1 }}) {
written += encoder.encode({{ field.name }}_, dst);
}{% endfor %}
return written;
}
{% else %}
{# A structure without fields serializes to zero bytes in every version. #}
uint32_t encode(Buffer::Instance&, EncodingContext&) const {
return 0;
}
{% endif %}
Expand Down
11 changes: 3 additions & 8 deletions source/extensions/filters/network/kafka/request_codec.cc
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,9 @@ void RequestDecoder::doParse(const Buffer::RawSlice& slice) {
}

void RequestEncoder::encode(const AbstractRequest& message) {
  // Kafka frames every request as an INT32 length prefix (big-endian, per the wire
  // protocol) followed by the serialized payload. Precomputing the payload size via
  // computeSize() lets us write the prefix up front, without the temporary buffer
  // the previous implementation needed to discover the length after encoding.
  const uint32_t size = htobe32(message.computeSize());
  output_.add(&size, sizeof(size)); // Encode data length.
  message.encode(output_);          // Encode data.
}

} // namespace Kafka
Expand Down
22 changes: 11 additions & 11 deletions source/extensions/filters/network/kafka/serialization.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,12 @@ constexpr static int32_t NULL_BYTES_LENGTH = -1;
* @return number of bytes consumed.
*/
template <typename DeserializerType, typename LengthType, typename ByteType>
size_t feedBytesIntoBuffers(absl::string_view& data, DeserializerType& length_deserializer,
bool& length_consumed_marker, LengthType& required,
std::vector<ByteType>& data_buffer, bool& ready,
const LengthType null_value_length, const bool allow_null_value) {
uint32_t feedBytesIntoBuffers(absl::string_view& data, DeserializerType& length_deserializer,
bool& length_consumed_marker, LengthType& required,
std::vector<ByteType>& data_buffer, bool& ready,
const LengthType null_value_length, const bool allow_null_value) {

const size_t length_consumed = length_deserializer.feed(data);
const uint32_t length_consumed = length_deserializer.feed(data);
if (!length_deserializer.ready()) {
// Break early: we still need to fill in length buffer.
return length_consumed;
Expand Down Expand Up @@ -68,8 +68,8 @@ size_t feedBytesIntoBuffers(absl::string_view& data, DeserializerType& length_de
return length_consumed;
}

const size_t data_consumed = std::min<size_t>(required, data.size());
const size_t written = data_buffer.size() - required;
const uint32_t data_consumed = std::min<uint32_t>(required, data.size());
const uint32_t written = data_buffer.size() - required;
if (data_consumed > 0) {
memcpy(data_buffer.data() + written, data.data(), data_consumed);
required -= data_consumed;
Expand All @@ -84,22 +84,22 @@ size_t feedBytesIntoBuffers(absl::string_view& data, DeserializerType& length_de
return length_consumed + data_consumed;
}

size_t StringDeserializer::feed(absl::string_view& data) {
uint32_t StringDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int16Deserializer, int16_t, char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_STRING_LENGTH, false);
}

size_t NullableStringDeserializer::feed(absl::string_view& data) {
uint32_t NullableStringDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int16Deserializer, int16_t, char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_STRING_LENGTH, true);
}

size_t BytesDeserializer::feed(absl::string_view& data) {
uint32_t BytesDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int32Deserializer, int32_t, unsigned char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_BYTES_LENGTH, false);
}

size_t NullableBytesDeserializer::feed(absl::string_view& data) {
uint32_t NullableBytesDeserializer::feed(absl::string_view& data) {
return feedBytesIntoBuffers<Int32Deserializer, int32_t, unsigned char>(
data, length_buf_, length_consumed_, required_, data_buf_, ready_, NULL_BYTES_LENGTH, true);
}
Expand Down
Loading