Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion source/extensions/filters/network/thrift_proxy/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,8 @@ envoy_cc_library(
hdrs = ["app_exception_impl.h"],
deps = [
":protocol_interface",
":thrift_lib",
"//include/envoy/buffer:buffer_interface",
"//source/extensions/filters/network/thrift_proxy/filters:filter_interface",
],
)

Expand All @@ -35,6 +35,7 @@ envoy_cc_library(
srcs = ["config.cc"],
hdrs = ["config.h"],
deps = [
":app_exception_lib",
":conn_manager_lib",
":decoder_lib",
":protocol_lib",
Expand Down Expand Up @@ -88,6 +89,17 @@ envoy_cc_library(
],
)

envoy_cc_library(
name = "metadata_lib",
srcs = ["metadata.cc"],
hdrs = ["metadata.h"],
external_deps = ["abseil_optional"],
deps = [
":thrift_lib",
"//source/common/common:macros",
],
)

envoy_cc_library(
name = "protocol_converter_lib",
hdrs = [
Expand All @@ -107,6 +119,8 @@ envoy_cc_library(
],
external_deps = ["abseil_optional"],
deps = [
":metadata_lib",
":thrift_lib",
"//include/envoy/buffer:buffer_interface",
"//include/envoy/registry",
"//source/common/common:assert_lib",
Expand Down Expand Up @@ -149,6 +163,9 @@ envoy_cc_library(
hdrs = ["transport.h"],
external_deps = ["abseil_optional"],
deps = [
":buffer_helper_lib",
":metadata_lib",
":thrift_lib",
"//include/envoy/buffer:buffer_interface",
"//include/envoy/registry",
"//source/common/common:assert_lib",
Expand All @@ -157,6 +174,15 @@ envoy_cc_library(
],
)

envoy_cc_library(
name = "thrift_lib",
hdrs = ["thrift.h"],
deps = [
"//source/common/common:assert_lib",
"//source/common/singleton:const_singleton",
],
)

envoy_cc_library(
name = "transport_lib",
srcs = [
Expand All @@ -170,7 +196,9 @@ envoy_cc_library(
"unframed_transport_impl.h",
],
deps = [
":app_exception_lib",
":buffer_helper_lib",
":metadata_lib",
":protocol_lib",
":transport_interface",
"//source/common/common:assert_lib",
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,19 +10,31 @@ static const std::string MessageField = "message";
static const std::string TypeField = "type";
static const std::string StopField = "";

void AppException::encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) {
proto.writeMessageBegin(buffer, method_name_, ThriftProxy::MessageType::Exception, seq_id_);
void AppException::encode(MessageMetadata& metadata, ThriftProxy::Protocol& proto,
Buffer::Instance& buffer) const {
// Handle cases where the exception occurs before the message name (e.g. some header transport
// errors).
if (!metadata.hasMethodName()) {
metadata.setMethodName("");
}
if (!metadata.hasSequenceId()) {
metadata.setSequenceId(0);
}

metadata.setMessageType(MessageType::Exception);

proto.writeMessageBegin(buffer, metadata);
proto.writeStructBegin(buffer, TApplicationException);

proto.writeFieldBegin(buffer, MessageField, ThriftProxy::FieldType::String, 1);
proto.writeString(buffer, error_message_);
proto.writeFieldBegin(buffer, MessageField, FieldType::String, 1);
proto.writeString(buffer, std::string(what()));
proto.writeFieldEnd(buffer);

proto.writeFieldBegin(buffer, TypeField, ThriftProxy::FieldType::I32, 2);
proto.writeFieldBegin(buffer, TypeField, FieldType::I32, 2);
proto.writeInt32(buffer, static_cast<int32_t>(type_));
proto.writeFieldEnd(buffer);

proto.writeFieldBegin(buffer, StopField, ThriftProxy::FieldType::Stop, 0);
proto.writeFieldBegin(buffer, StopField, FieldType::Stop, 0);

proto.writeStructEnd(buffer);
proto.writeMessageEnd(buffer);
Expand Down
37 changes: 10 additions & 27 deletions source/extensions/filters/network/thrift_proxy/app_exception_impl.h
Original file line number Diff line number Diff line change
@@ -1,41 +1,24 @@
#pragma once

#include "extensions/filters/network/thrift_proxy/filters/filter.h"
#include "envoy/common/exception.h"

#include "extensions/filters/network/thrift_proxy/metadata.h"
#include "extensions/filters/network/thrift_proxy/protocol.h"
#include "extensions/filters/network/thrift_proxy/thrift.h"

namespace Envoy {
namespace Extensions {
namespace NetworkFilters {
namespace ThriftProxy {

/**
* Thrift Application Exception types.
* See https://github.com/apache/thrift/blob/master/doc/specs/thrift-rpc.md
*/
enum class AppExceptionType {
Unknown = 0,
UnknownMethod = 1,
InvalidMessageType = 2,
WrongMethodName = 3,
BadSequenceId = 4,
MissingResult = 5,
InternalError = 6,
ProtocolError = 7,
InvalidTransform = 8,
InvalidProtocol = 9,
UnsupportedClientType = 10,
};

struct AppException : public ThriftFilters::DirectResponse {
AppException(const absl::string_view method_name, int32_t seq_id, AppExceptionType type,
const std::string& error_message)
: method_name_(method_name), seq_id_(seq_id), type_(type), error_message_(error_message) {}
struct AppException : public EnvoyException, public DirectResponse {
AppException(AppExceptionType type, const std::string& what)
: EnvoyException(what), type_(type) {}
AppException(const AppException& ex) : EnvoyException(ex.what()), type_(ex.type_) {}

void encode(ThriftProxy::Protocol& proto, Buffer::Instance& buffer) override;
void encode(MessageMetadata& metadata, Protocol& proto, Buffer::Instance& buffer) const override;

const std::string method_name_;
const int32_t seq_id_;
const AppExceptionType type_;
const std::string error_message_;
};

} // namespace ThriftProxy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,7 @@ namespace ThriftProxy {

const uint16_t BinaryProtocolImpl::Magic = 0x8001;

bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name,
MessageType& msg_type, int32_t& seq_id) {
bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) {
// Minimum message length:
// version: 2 bytes +
// unused: 1 byte +
Expand Down Expand Up @@ -52,13 +51,14 @@ bool BinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string&
buffer.drain(8);

if (name_len > 0) {
name.assign(std::string(static_cast<char*>(buffer.linearize(name_len)), name_len));
metadata.setMethodName(
std::string(static_cast<const char*>(buffer.linearize(name_len)), name_len));
buffer.drain(name_len);
} else {
name.clear();
metadata.setMethodName("");
}
msg_type = type;
seq_id = BufferHelper::drainI32(buffer);
metadata.setMessageType(type);
metadata.setSequenceId(BufferHelper::drainI32(buffer));

return true;
}
Expand Down Expand Up @@ -253,7 +253,7 @@ bool BinaryProtocolImpl::readString(Buffer::Instance& buffer, std::string& value
}

buffer.drain(4);
value.assign(static_cast<char*>(buffer.linearize(str_len)), str_len);
value.assign(static_cast<const char*>(buffer.linearize(str_len)), str_len);
buffer.drain(str_len);
return true;
}
Expand All @@ -262,12 +262,12 @@ bool BinaryProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& value
return readString(buffer, value);
}

void BinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name,
MessageType msg_type, int32_t seq_id) {
void BinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer,
const MessageMetadata& metadata) {
BufferHelper::writeU16(buffer, Magic);
BufferHelper::writeU16(buffer, static_cast<uint16_t>(msg_type));
writeString(buffer, name);
BufferHelper::writeI32(buffer, seq_id);
BufferHelper::writeU16(buffer, static_cast<uint16_t>(metadata.messageType()));
writeString(buffer, metadata.methodName());
BufferHelper::writeI32(buffer, metadata.sequenceId());
}

void BinaryProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) {
Expand Down Expand Up @@ -362,8 +362,7 @@ void BinaryProtocolImpl::writeBinary(Buffer::Instance& buffer, const std::string
writeString(buffer, value);
}

bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name,
MessageType& msg_type, int32_t& seq_id) {
bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) {
// Minimum message length:
// name len: 4 bytes +
// name: 0 bytes +
Expand All @@ -387,24 +386,25 @@ bool LaxBinaryProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::stri

buffer.drain(4);
if (name_len > 0) {
name.assign(std::string(static_cast<char*>(buffer.linearize(name_len)), name_len));
metadata.setMethodName(
std::string(static_cast<const char*>(buffer.linearize(name_len)), name_len));
buffer.drain(name_len);
} else {
name.clear();
metadata.setMethodName("");
}

msg_type = type;
seq_id = BufferHelper::peekI32(buffer, 1);
metadata.setMessageType(type);
metadata.setSequenceId(BufferHelper::peekI32(buffer, 1));
buffer.drain(5);

return true;
}

void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name,
MessageType msg_type, int32_t seq_id) {
writeString(buffer, name);
BufferHelper::writeI8(buffer, static_cast<int8_t>(msg_type));
BufferHelper::writeI32(buffer, seq_id);
void LaxBinaryProtocolImpl::writeMessageBegin(Buffer::Instance& buffer,
const MessageMetadata& metadata) {
writeString(buffer, metadata.methodName());
BufferHelper::writeI8(buffer, static_cast<int8_t>(metadata.messageType()));
BufferHelper::writeI32(buffer, metadata.sequenceId());
}

class BinaryProtocolConfigFactory : public ProtocolFactoryBase<BinaryProtocolImpl> {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,7 @@ class BinaryProtocolImpl : public Protocol {
// Protocol
const std::string& name() const override { return ProtocolNames::get().BINARY; }
ProtocolType type() const override { return ProtocolType::Binary; }
bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type,
int32_t& seq_id) override;
bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override;
bool readMessageEnd(Buffer::Instance& buffer) override;
bool readStructBegin(Buffer::Instance& buffer, std::string& name) override;
bool readStructEnd(Buffer::Instance& buffer) override;
Expand All @@ -46,8 +45,7 @@ class BinaryProtocolImpl : public Protocol {
bool readDouble(Buffer::Instance& buffer, double& value) override;
bool readString(Buffer::Instance& buffer, std::string& value) override;
bool readBinary(Buffer::Instance& buffer, std::string& value) override;
void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type,
int32_t seq_id) override;
void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override;
void writeMessageEnd(Buffer::Instance& buffer) override;
void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override;
void writeStructEnd(Buffer::Instance& buffer) override;
Expand Down Expand Up @@ -86,10 +84,8 @@ class LaxBinaryProtocolImpl : public BinaryProtocolImpl {

const std::string& name() const override { return ProtocolNames::get().LAX_BINARY; }

bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type,
int32_t& seq_id) override;
void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type,
int32_t seq_id) override;
bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override;
void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override;
};

} // namespace ThriftProxy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,7 @@ namespace ThriftProxy {
const uint16_t CompactProtocolImpl::Magic = 0x8201;
const uint16_t CompactProtocolImpl::MagicMask = 0xFF1F;

bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string& name,
MessageType& msg_type, int32_t& seq_id) {
bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) {
// Minimum message length:
// protocol, message type, and version: 2 bytes +
// seq id (var int): 1 byte +
Expand Down Expand Up @@ -64,13 +63,14 @@ bool CompactProtocolImpl::readMessageBegin(Buffer::Instance& buffer, std::string
buffer.drain(id_size + name_len_size + 2);

if (name_len > 0) {
name.assign(std::string(static_cast<char*>(buffer.linearize(name_len)), name_len));
metadata.setMethodName(
std::string(static_cast<const char*>(buffer.linearize(name_len)), name_len));
buffer.drain(name_len);
} else {
name.clear();
metadata.setMethodName("");
}
msg_type = type;
seq_id = id;
metadata.setMessageType(type);
metadata.setSequenceId(id);

return true;
}
Expand Down Expand Up @@ -373,7 +373,7 @@ bool CompactProtocolImpl::readString(Buffer::Instance& buffer, std::string& valu
}

buffer.drain(len_size);
value.assign(static_cast<char*>(buffer.linearize(str_len)), str_len);
value.assign(static_cast<const char*>(buffer.linearize(str_len)), str_len);
buffer.drain(str_len);
return true;
}
Expand All @@ -382,17 +382,17 @@ bool CompactProtocolImpl::readBinary(Buffer::Instance& buffer, std::string& valu
return readString(buffer, value);
}

void CompactProtocolImpl::writeMessageBegin(Buffer::Instance& buffer, const std::string& name,
MessageType msg_type, int32_t seq_id) {
UNREFERENCED_PARAMETER(name);
void CompactProtocolImpl::writeMessageBegin(Buffer::Instance& buffer,
const MessageMetadata& metadata) {
MessageType msg_type = metadata.messageType();

uint16_t ptv = (Magic & MagicMask) | (static_cast<uint16_t>(msg_type) << 5);
ASSERT((ptv & MagicMask) == Magic);
ASSERT((ptv & ~MagicMask) >> 5 == static_cast<uint16_t>(msg_type));

BufferHelper::writeU16(buffer, ptv);
BufferHelper::writeVarIntI32(buffer, seq_id);
writeString(buffer, name);
BufferHelper::writeVarIntI32(buffer, metadata.sequenceId());
writeString(buffer, metadata.methodName());
}

void CompactProtocolImpl::writeMessageEnd(Buffer::Instance& buffer) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,8 +26,7 @@ class CompactProtocolImpl : public Protocol {
// Protocol
const std::string& name() const override { return ProtocolNames::get().COMPACT; }
ProtocolType type() const override { return ProtocolType::Compact; }
bool readMessageBegin(Buffer::Instance& buffer, std::string& name, MessageType& msg_type,
int32_t& seq_id) override;
bool readMessageBegin(Buffer::Instance& buffer, MessageMetadata& metadata) override;
bool readMessageEnd(Buffer::Instance& buffer) override;
bool readStructBegin(Buffer::Instance& buffer, std::string& name) override;
bool readStructEnd(Buffer::Instance& buffer) override;
Expand All @@ -49,8 +48,7 @@ class CompactProtocolImpl : public Protocol {
bool readDouble(Buffer::Instance& buffer, double& value) override;
bool readString(Buffer::Instance& buffer, std::string& value) override;
bool readBinary(Buffer::Instance& buffer, std::string& value) override;
void writeMessageBegin(Buffer::Instance& buffer, const std::string& name, MessageType msg_type,
int32_t seq_id) override;
void writeMessageBegin(Buffer::Instance& buffer, const MessageMetadata& metadata) override;
void writeMessageEnd(Buffer::Instance& buffer) override;
void writeStructBegin(Buffer::Instance& buffer, const std::string& name) override;
void writeStructEnd(Buffer::Instance& buffer) override;
Expand Down
4 changes: 2 additions & 2 deletions source/extensions/filters/network/thrift_proxy/config.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,8 +45,8 @@ class ConfigImpl : public Config,
void createFilterChain(ThriftFilters::FilterChainFactoryCallbacks& callbacks) override;

// Router::Config
Router::RouteConstSharedPtr route(const std::string& method_name) const override {
return route_matcher_->route(method_name);
Router::RouteConstSharedPtr route(const MessageMetadata& metadata) const override {
return route_matcher_->route(metadata);
}

// Config
Expand Down
Loading