Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 15 additions & 19 deletions source/common/buffer/buffer_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,15 @@

namespace Buffer {

// RawSlice is the same structure as evbuffer_iovec. This was put into place to avoid leaking
// libevent into most code since we will likely replace evbuffer with our own implementation at
// some point. However, we can avoid a bunch of copies since the structure is the same.
static_assert(sizeof(RawSlice) == sizeof(evbuffer_iovec), "RawSlice != evbuffer_iovec");
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nice cleanup.

static_assert(offsetof(RawSlice, mem_) == offsetof(evbuffer_iovec, iov_base),
"RawSlice != evbuffer_iovec");
static_assert(offsetof(RawSlice, len_) == offsetof(evbuffer_iovec, iov_len),
"RawSlice != evbuffer_iovec");

void OwnedImpl::add(const void* data, uint64_t size) { evbuffer_add(buffer_.get(), data, size); }

void OwnedImpl::add(const std::string& data) {
Expand All @@ -22,12 +31,8 @@ void OwnedImpl::add(const Instance& data) {
}

void OwnedImpl::commit(RawSlice* iovecs, uint64_t num_iovecs) {
evbuffer_iovec local_iovecs[num_iovecs];
for (uint64_t i = 0; i < num_iovecs; i++) {
local_iovecs[i].iov_len = iovecs[i].len_;
local_iovecs[i].iov_base = iovecs[i].mem_;
}
int rc = evbuffer_commit_space(buffer_.get(), local_iovecs, num_iovecs);
int rc =
evbuffer_commit_space(buffer_.get(), reinterpret_cast<evbuffer_iovec*>(iovecs), num_iovecs);
ASSERT(rc == 0);
UNREFERENCED_PARAMETER(rc);
}
Expand All @@ -40,13 +45,8 @@ void OwnedImpl::drain(uint64_t size) {
}

uint64_t OwnedImpl::getRawSlices(RawSlice* out, uint64_t out_size) const {
evbuffer_iovec iovecs[out_size];
uint64_t needed_size = evbuffer_peek(buffer_.get(), -1, nullptr, iovecs, out_size);
for (uint64_t i = 0; i < std::min(out_size, needed_size); i++) {
out[i].mem_ = iovecs[i].iov_base;
out[i].len_ = iovecs[i].iov_len;
}
return needed_size;
return evbuffer_peek(buffer_.get(), -1, nullptr, reinterpret_cast<evbuffer_iovec*>(out),
out_size);
}

uint64_t OwnedImpl::length() const { return evbuffer_get_length(buffer_.get()); }
Expand Down Expand Up @@ -79,13 +79,9 @@ int OwnedImpl::read(int fd, uint64_t max_length) {
}

uint64_t OwnedImpl::reserve(uint64_t length, RawSlice* iovecs, uint64_t num_iovecs) {
evbuffer_iovec local_iovecs[num_iovecs];
uint64_t ret = evbuffer_reserve_space(buffer_.get(), length, local_iovecs, num_iovecs);
uint64_t ret = evbuffer_reserve_space(buffer_.get(), length,
reinterpret_cast<evbuffer_iovec*>(iovecs), num_iovecs);
ASSERT(ret >= 1);
for (uint64_t i = 0; i < ret; i++) {
iovecs[i].len_ = local_iovecs[i].iov_len;
iovecs[i].mem_ = local_iovecs[i].iov_base;
}
return ret;
}

Expand Down
6 changes: 6 additions & 0 deletions source/common/network/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -231,6 +231,12 @@ void ConnectionImpl::write(Buffer::Instance& data) {

if (data.length() > 0) {
conn_log_trace("writing {} bytes", *this, data.length());
// TODO(mattklein123): All data currently gets moved from the source buffer to the write buffer.
// This can lead to inefficient behavior if writing a bunch of small chunks. In this case, it
// would likely be more efficient to copy data below a certain size. VERY IMPORTANT: If this is
// ever changed, read the comment in Ssl::ConnectionImpl::doWriteToSocket() VERY carefully.
// That code assumes that we never change existing write_buffer_ chain elements between calls
// to SSL_write(). That code will have to change if we ever copy here.
write_buffer_.move(data);
if (!(state_ & InternalState::Connecting)) {
file_event_->activate(Event::FileReadyType::Write);
Expand Down
79 changes: 45 additions & 34 deletions source/common/ssl/connection_impl.cc
Original file line number Diff line number Diff line change
Expand Up @@ -134,47 +134,58 @@ Network::ConnectionImpl::IoResult ConnectionImpl::doWriteToSocket() {
}
}

if (write_buffer_.length() == 0) {
return {PostIoAction::KeepOpen, 0};
}
uint64_t total_bytes_written = 0;
bool keep_writing = true;
while ((write_buffer_.length() > 0) && keep_writing) {
// Protect against stack overflow if the buffer has a very large buffer chain.
// TODO(mattklein123): The current evbuffer Buffer::Instance implementation will iterate through
// the entire chain each time this is called to determine how many slices would be needed. In
// this case, we don't care, and only want to fill up to MAX_SLICES. When we swap out evbuffer
// we can change this behavior.
// TODO(mattklein123): As it relates to our fairness efforts, we might want to limit the number
// of iterations of this loop, either by pure iterations, bytes written, etc.
const uint64_t MAX_SLICES = 32;
Buffer::RawSlice slices[MAX_SLICES];
uint64_t num_slices = std::min(MAX_SLICES, write_buffer_.getRawSlices(slices, MAX_SLICES));

uint64_t inner_bytes_written = 0;
for (uint64_t i = 0; i < num_slices; i++) {
// SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call
// it again with the same parameters. Most implementations keep track of the last write size.
// In our case we don't need to do that because: a) SSL_write() will not write partial
// buffers. b) We only move() into the write buffer, which means that it's impossible for a
// particular chain to increase in size. So as long as we start writing where we left off we
// are guaranteed to call SSL_write() with the same parameters.
int rc = SSL_write(ssl_.get(), slices[i].mem_, slices[i].len_);
conn_log_trace("ssl write returns: {}", *this, rc);
if (rc > 0) {
inner_bytes_written += rc;
total_bytes_written += rc;
} else {
int err = SSL_get_error(ssl_.get(), rc);
switch (err) {
case SSL_ERROR_WANT_WRITE:
keep_writing = false;
break;
case SSL_ERROR_WANT_READ:
// Renegotiation has started. We don't handle renegotiation so just fall through.
default:
drainErrorQueue();
return {PostIoAction::Close, total_bytes_written};
}

uint64_t num_slices = write_buffer_.getRawSlices(nullptr, 0);
Buffer::RawSlice slices[num_slices];
write_buffer_.getRawSlices(slices, num_slices);

uint64_t bytes_written = 0;
for (uint64_t i = 0; i < num_slices; i++) {
// SSL_write() requires that if a previous call returns SSL_ERROR_WANT_WRITE, we need to call
// it again with the same parameters. Most implementations keep track of the last write size.
// In our case we don't need to do that because: a) SSL_write() will not write partial buffers.
// b) We only move() into the write buffer, which means that it's impossible for a particular
// chain to increase in size. So as long as we start writing where we left off we are guaranteed
// to call SSL_write() with the same parameters.
int rc = SSL_write(ssl_.get(), slices[i].mem_, slices[i].len_);
conn_log_trace("ssl write returns: {}", *this, rc);
if (rc > 0) {
bytes_written += rc;
} else {
int err = SSL_get_error(ssl_.get(), rc);
switch (err) {
case SSL_ERROR_WANT_WRITE:
break;
case SSL_ERROR_WANT_READ:
// Renegotiation has started. We don't handle renegotiation so just fall through.
default:
drainErrorQueue();
return {PostIoAction::Close, bytes_written};
}

break;
}
}

if (bytes_written > 0) {
write_buffer_.drain(bytes_written);
// Draining must be done within the inner loop, otherwise we will keep getting the same slices
// at the beginning of the buffer.
if (inner_bytes_written > 0) {
write_buffer_.drain(inner_bytes_written);
}
}

return {PostIoAction::KeepOpen, bytes_written};
return {PostIoAction::KeepOpen, total_bytes_written};
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

There are other locations where the Buffer::RawSlice slices[num_slices] pattern appears (from a grep), e.g. source/common/http/http1/codec_impl.cc. Naively it seems similar concerns would exist there. Is there an easy way to avoid this stack overflow danger globally?

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unfortunately there is no global way to do this, since callers currently assume that getRawSlices() either returns all slices or reports the size of the slices array required to hold them all. I will put in a TODO to figure this out and clean this up globally. I want to get a fix for this crash out since we saw it in production.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I opened #593 instead of a TODO.

}

void ConnectionImpl::onConnected() { ASSERT(!handshake_complete_); }
Expand Down
28 changes: 18 additions & 10 deletions test/common/ssl/connection_impl_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -247,9 +247,8 @@ TEST(SslConnectionImplTest, SslError) {

class SslReadBufferLimitTest : public testing::Test {
public:
void readBufferLimitTest(uint32_t read_buffer_limit, uint32_t expected_chunk_size) {
const uint32_t buffer_size = 256 * 1024;

void readBufferLimitTest(uint32_t read_buffer_limit, uint32_t expected_chunk_size,
uint32_t write_size, uint32_t num_writes) {
Stats::IsolatedStoreImpl stats_store;
Event::DispatcherImpl dispatcher;
Network::TcpListenSocket socket(uint32_t(10000), true);
Expand Down Expand Up @@ -306,10 +305,10 @@ class SslReadBufferLimitTest : public testing::Test {
EXPECT_CALL(*read_filter, onNewConnection());
EXPECT_CALL(*read_filter, onData(_))
.WillRepeatedly(Invoke([&](Buffer::Instance& data) -> Network::FilterStatus {
EXPECT_EQ(expected_chunk_size, data.length());
EXPECT_GE(expected_chunk_size, data.length());
filter_seen += data.length();
data.drain(data.length());
if (filter_seen == buffer_size) {
if (filter_seen == (write_size * num_writes)) {
server_connection->close(Network::ConnectionCloseType::FlushWrite);
}
return Network::FilterStatus::StopIteration;
Expand All @@ -320,18 +319,27 @@ class SslReadBufferLimitTest : public testing::Test {
EXPECT_CALL(client_callbacks, onEvent(Network::ConnectionEvent::Connected));
EXPECT_CALL(client_callbacks, onEvent(Network::ConnectionEvent::RemoteClose))
.WillOnce(Invoke([&](uint32_t) -> void {
EXPECT_EQ(buffer_size, filter_seen);
EXPECT_EQ((write_size * num_writes), filter_seen);
dispatcher.exit();
}));

Buffer::OwnedImpl data(std::string(buffer_size, 'a'));
client_connection->write(data);
for (uint32_t i = 0; i < num_writes; i++) {
Buffer::OwnedImpl data(std::string(write_size, 'a'));
client_connection->write(data);
}

dispatcher.run(Event::Dispatcher::RunType::Block);
}
};

TEST_F(SslReadBufferLimitTest, NoLimit) { readBufferLimitTest(0, 256 * 1024); }
TEST_F(SslReadBufferLimitTest, NoLimit) { readBufferLimitTest(0, 256 * 1024, 256 * 1024, 1); }

TEST_F(SslReadBufferLimitTest, SomeLimit) { readBufferLimitTest(32 * 1024, 32 * 1024); }
TEST_F(SslReadBufferLimitTest, NoLimitSmallWrites) {
readBufferLimitTest(0, 256 * 1024, 1, 256 * 1024);
}

TEST_F(SslReadBufferLimitTest, SomeLimit) {
readBufferLimitTest(32 * 1024, 32 * 1024, 256 * 1024, 1);
}

} // Ssl