Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions source/extensions/filters/network/thrift_proxy/conn_manager.cc
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ Network::FilterStatus ConnectionManager::onData(Buffer::Instance& data, bool end

void ConnectionManager::dispatch() {
if (stopped_) {
ENVOY_LOG(debug, "thrift filter stopped");
ENVOY_CONN_LOG(debug, "thrift filter stopped", read_callbacks_->connection());
return;
}

Expand All @@ -55,7 +55,7 @@ void ConnectionManager::dispatch() {
sendLocalReply(*(*rpcs_.begin())->metadata_, ex);
}
} catch (const EnvoyException& ex) {
ENVOY_LOG(error, "thrift error: {}", ex.what());
ENVOY_CONN_LOG(error, "thrift error: {}", read_callbacks_->connection(), ex.what());

// Use the current rpc to send an error downstream, if possible.
rpcs_.front()->onError(ex.what());
Expand Down Expand Up @@ -88,6 +88,7 @@ void ConnectionManager::sendLocalReply(MessageMetadata& metadata, const DirectRe
}

void ConnectionManager::continueDecoding() {
ENVOY_CONN_LOG(debug, "thrift filter continued", read_callbacks_->connection());
stopped_ = false;
dispatch();
}
Expand All @@ -104,6 +105,9 @@ void ConnectionManager::resetAllRpcs() {

void ConnectionManager::initializeReadFilterCallbacks(Network::ReadFilterCallbacks& callbacks) {
read_callbacks_ = &callbacks;

read_callbacks_->connection().addConnectionCallbacks(*this);
read_callbacks_->connection().enableHalfClose(true);
}

void ConnectionManager::onEvent(Network::ConnectionEvent event) {
Expand All @@ -119,7 +123,7 @@ void ConnectionManager::onEvent(Network::ConnectionEvent event) {
}

ThriftFilters::DecoderFilter& ConnectionManager::newDecoderFilter() {
ENVOY_LOG(debug, "new decoder filter");
ENVOY_LOG(trace, "new decoder filter");

ActiveRpcPtr new_rpc(new ActiveRpc(*this));
new_rpc->createFilterChain();
Expand Down Expand Up @@ -296,7 +300,8 @@ bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) {
decoder_filter_->resetUpstreamConnection();
return true;
} catch (const EnvoyException& ex) {
ENVOY_LOG(error, "thrift response error: {}", ex.what());
ENVOY_CONN_LOG(error, "thrift response error: {}", parent_.read_callbacks_->connection(),
ex.what());
parent_.stats_.response_decoding_error_.inc();

onError(ex.what());
Expand All @@ -307,7 +312,6 @@ bool ConnectionManager::ActiveRpc::upstreamData(Buffer::Instance& buffer) {

void ConnectionManager::ActiveRpc::resetDownstreamConnection() {
parent_.read_callbacks_->connection().close(Network::ConnectionCloseType::NoFlush);
parent_.doDeferredRpcDestroy(*this);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

With this line removed, where does the cleanup happen instead?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Because the downstream connection is closed, ConnectionManager::onEvent is invoked, which will destroy any items in the rpcs_ list. What was happening before was that the RPC was being destroyed but not removed from rpcs_.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

(Which causes a subsequent nullptr dereference.)

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, now it makes sense.

}

} // namespace ThriftProxy
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -137,8 +137,7 @@ ThriftFilters::FilterStatus Router::messageBegin(MessageMetadataSharedPtr metada
ENVOY_STREAM_LOG(debug, "router decoding request", *callbacks_);

upstream_request_.reset(new UpstreamRequest(*this, *conn_pool, metadata));
upstream_request_->start();
return ThriftFilters::FilterStatus::StopIteration;
return upstream_request_->start();
}

ThriftFilters::FilterStatus Router::messageEnd() {
Expand Down Expand Up @@ -215,14 +214,22 @@ Router::UpstreamRequest::UpstreamRequest(Router& parent, Tcp::ConnectionPool::In

Router::UpstreamRequest::~UpstreamRequest() {}

void Router::UpstreamRequest::start() {
ThriftFilters::FilterStatus Router::UpstreamRequest::start() {
Tcp::ConnectionPool::Cancellable* handle = conn_pool_.newConnection(*this);
if (handle) {
// Pause while we wait for a connection.
conn_pool_handle_ = handle;
return ThriftFilters::FilterStatus::StopIteration;
}

return ThriftFilters::FilterStatus::Continue;
}

void Router::UpstreamRequest::resetStream() {
if (conn_pool_handle_) {
conn_pool_handle_->cancel();
}

if (conn_data_ != nullptr) {
conn_data_->connection().close(Network::ConnectionCloseType::NoFlush);
conn_data_.reset();
Expand All @@ -231,13 +238,18 @@ void Router::UpstreamRequest::resetStream() {

void Router::UpstreamRequest::onPoolFailure(Tcp::ConnectionPool::PoolFailureReason reason,
Upstream::HostDescriptionConstSharedPtr host) {
conn_pool_handle_ = nullptr;

// Mimic an upstream reset.
onUpstreamHostSelected(host);
onResetStream(reason);
}

void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr&& conn_data,
Upstream::HostDescriptionConstSharedPtr host) {
// Only invoke continueDecoding if we'd previously stopped the filter chain.
bool continue_decoding = conn_pool_handle_ != nullptr;

onUpstreamHostSelected(host);
conn_data_ = std::move(conn_data);
conn_data_->addUpstreamCallbacks(parent_);
Expand All @@ -257,7 +269,9 @@ void Router::UpstreamRequest::onPoolReady(Tcp::ConnectionPool::ConnectionDataPtr
// TODO(zuercher): need to use an upstream-connection-specific sequence id
parent_.convertMessageBegin(metadata_);

parent_.callbacks_->continueDecoding();
if (continue_decoding) {
parent_.callbacks_->continueDecoding();
}
}

void Router::UpstreamRequest::onRequestComplete() { request_complete_ = true; }
Expand All @@ -272,6 +286,13 @@ void Router::UpstreamRequest::onUpstreamHostSelected(Upstream::HostDescriptionCo
}

void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReason reason) {
if (metadata_->messageType() == MessageType::Oneway) {
// For oneway requests, we should not attempt a response. Reset the downstream to signal
// an error.
parent_.callbacks_->resetDownstreamConnection();
return;
}

switch (reason) {
case Tcp::ConnectionPool::PoolFailureReason::Overflow:
parent_.callbacks_->sendLocalReply(AppException(
Expand All @@ -289,6 +310,7 @@ void Router::UpstreamRequest::onResetStream(Tcp::ConnectionPool::PoolFailureReas
return;
}

// Error occurred after a partial response, propagate the reset to the downstream.
parent_.callbacks_->resetDownstreamConnection();
break;
default:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -105,7 +105,7 @@ class Router : public Tcp::ConnectionPool::UpstreamCallbacks,
MessageMetadataSharedPtr& metadata);
~UpstreamRequest();

void start();
ThriftFilters::FilterStatus start();
void resetStream();

// Tcp::ConnectionPool::Callbacks
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def main(cfg):
elif cfg.response == "exception":
print("Thrift Server will throw Thrift exceptions for all messages")

server = TServer.TSimpleServer(processor, transport, transport_factory, protocol_factory)
server = TServer.TThreadedServer(processor, transport, transport_factory, protocol_factory)
try:
server.serve()
except KeyboardInterrupt:
Expand Down
18 changes: 18 additions & 0 deletions test/extensions/filters/network/thrift_proxy/integration_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -272,6 +272,24 @@ TEST_P(ThriftConnManagerIntegrationTest, Oneway) {
EXPECT_EQ(1U, counter->value());
}

TEST_P(ThriftConnManagerIntegrationTest, OnewayEarlyClose) {
initializeOneway();

IntegrationTcpClientPtr tcp_client = makeTcpConnection(lookupPort("listener_0"));
tcp_client->write(request_bytes_.toString());
tcp_client->close();

FakeRawConnectionPtr fake_upstream_connection;
ASSERT_TRUE(fake_upstreams_[0]->waitForRawConnection(fake_upstream_connection));
std::string data;
ASSERT_TRUE(fake_upstream_connection->waitForData(request_bytes_.length(), &data));
Buffer::OwnedImpl upstream_request(data);
EXPECT_EQ(request_bytes_.toString(), upstream_request.toString());

Stats::CounterSharedPtr counter = test_server_->counter("thrift.thrift_stats.request_oneway");
EXPECT_EQ(1U, counter->value());
}

} // namespace ThriftProxy
} // namespace NetworkFilters
} // namespace Extensions
Expand Down
115 changes: 115 additions & 0 deletions test/extensions/filters/network/thrift_proxy/router_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,56 @@ class ThriftRouterTestBase {
EXPECT_NE(nullptr, upstream_callbacks_);
}

void startRequestWithExistingConnection(MessageType msg_type) {
EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->transportBegin({}));

EXPECT_CALL(callbacks_, route()).WillOnce(Return(route_ptr_));
EXPECT_CALL(*route_, routeEntry()).WillOnce(Return(&route_entry_));
EXPECT_CALL(route_entry_, clusterName()).WillRepeatedly(ReturnRef(cluster_name_));

initializeMetadata(msg_type);

EXPECT_CALL(*context_.cluster_manager_.tcp_conn_pool_.connection_data_, addUpstreamCallbacks(_))
.WillOnce(Invoke([&](Tcp::ConnectionPool::UpstreamCallbacks& cb) -> void {
upstream_callbacks_ = &cb;
}));

NiceMock<Network::MockClientConnection> connection;
EXPECT_CALL(callbacks_, connection()).WillRepeatedly(Return(&connection));
EXPECT_EQ(&connection, router_->downstreamConnection());

// Not yet implemented:
EXPECT_EQ(absl::optional<uint64_t>(), router_->computeHashKey());
EXPECT_EQ(nullptr, router_->metadataMatchCriteria());
EXPECT_EQ(nullptr, router_->downstreamHeaders());

EXPECT_CALL(callbacks_, downstreamTransportType()).WillOnce(Return(TransportType::Framed));
transport_ = new NiceMock<MockTransport>();
ON_CALL(*transport_, type()).WillByDefault(Return(TransportType::Framed));

EXPECT_CALL(callbacks_, downstreamProtocolType()).WillOnce(Return(ProtocolType::Binary));
protocol_ = new NiceMock<MockProtocol>();
ON_CALL(*protocol_, type()).WillByDefault(Return(ProtocolType::Binary));
EXPECT_CALL(*protocol_, writeMessageBegin(_, _))
.WillOnce(Invoke([&](Buffer::Instance&, const MessageMetadata& metadata) -> void {
EXPECT_EQ(metadata_->methodName(), metadata.methodName());
EXPECT_EQ(metadata_->messageType(), metadata.messageType());
EXPECT_EQ(metadata_->sequenceId(), metadata.sequenceId());
}));

EXPECT_CALL(callbacks_, continueDecoding()).Times(0);
EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_, newConnection(_))
.WillOnce(
Invoke([&](Tcp::ConnectionPool::Callbacks& cb) -> Tcp::ConnectionPool::Cancellable* {
context_.cluster_manager_.tcp_conn_pool_.newConnectionImpl(cb);
context_.cluster_manager_.tcp_conn_pool_.poolReady(upstream_connection_);
return nullptr;
}));

EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->messageBegin(metadata_));
EXPECT_NE(nullptr, upstream_callbacks_);
}

void sendTrivialStruct(FieldType field_type) {
EXPECT_CALL(*protocol_, writeStructBegin(_, ""));
EXPECT_EQ(ThriftFilters::FilterStatus::Continue, router_->structBegin({}));
Expand Down Expand Up @@ -334,6 +384,18 @@ TEST_F(ThriftRouterTest, PoolOverflowFailure) {
Tcp::ConnectionPool::PoolFailureReason::Overflow);
}

TEST_F(ThriftRouterTest, PoolConnectionFailureWithOnewayMessage) {
initializeRouter();
startRequest(MessageType::Oneway);

EXPECT_CALL(callbacks_, sendLocalReply(_)).Times(0);
EXPECT_CALL(callbacks_, resetDownstreamConnection());
context_.cluster_manager_.tcp_conn_pool_.poolFailure(
Tcp::ConnectionPool::PoolFailureReason::RemoteConnectionFailure);

destroyRouter();
}

TEST_F(ThriftRouterTest, NoRoute) {
initializeRouter();
initializeMetadata(MessageType::Call);
Expand Down Expand Up @@ -422,6 +484,47 @@ TEST_F(ThriftRouterTest, TruncatedResponse) {
destroyRouter();
}

TEST_F(ThriftRouterTest, UpstreamRemoteCloseMidResponse) {
initializeRouter();
startRequest(MessageType::Call);
connectUpstream();

EXPECT_CALL(callbacks_, sendLocalReply(_))
.WillOnce(Invoke([&](const DirectResponse& response) -> void {
auto& app_ex = dynamic_cast<const AppException&>(response);
EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_);
EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*"));
}));
upstream_callbacks_->onEvent(Network::ConnectionEvent::RemoteClose);
destroyRouter();
}

TEST_F(ThriftRouterTest, UpstreamLocalCloseMidResponse) {
initializeRouter();
startRequest(MessageType::Call);
connectUpstream();

EXPECT_CALL(callbacks_, sendLocalReply(_))
.WillOnce(Invoke([&](const DirectResponse& response) -> void {
auto& app_ex = dynamic_cast<const AppException&>(response);
EXPECT_EQ(AppExceptionType::InternalError, app_ex.type_);
EXPECT_THAT(app_ex.what(), ContainsRegex(".*connection failure.*"));
}));
upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose);
destroyRouter();
}

TEST_F(ThriftRouterTest, UpstreamCloseAfterResponse) {
initializeRouter();
startRequest(MessageType::Call);
connectUpstream();
sendTrivialStruct(FieldType::String);
completeRequest();

upstream_callbacks_->onEvent(Network::ConnectionEvent::LocalClose);
destroyRouter();
}

TEST_F(ThriftRouterTest, UpstreamDataTriggersReset) {
initializeRouter();
startRequest(MessageType::Call);
Expand Down Expand Up @@ -476,6 +579,9 @@ TEST_F(ThriftRouterTest, UnexpectedUpstreamLocalClose) {
TEST_F(ThriftRouterTest, UnexpectedRouterDestroyBeforeUpstreamConnect) {
initializeRouter();
startRequest(MessageType::Call);

EXPECT_EQ(1, context_.cluster_manager_.tcp_conn_pool_.handles_.size());
EXPECT_CALL(context_.cluster_manager_.tcp_conn_pool_.handles_.front(), cancel());
destroyRouter();
}

Expand Down Expand Up @@ -510,6 +616,15 @@ TEST_P(ThriftRouterFieldTypeTest, Call) {
destroyRouter();
}

TEST_F(ThriftRouterTest, CallWithExistingConnection) {
initializeRouter();
startRequestWithExistingConnection(MessageType::Call);
sendTrivialStruct(FieldType::I32);
completeRequest();
returnResponse();
destroyRouter();
}

TEST_P(ThriftRouterContainerTest, DecoderFilterCallbacks) {
FieldType field_type = GetParam();

Expand Down