diff --git a/contrib/postgres_proxy/filters/network/test/postgres_decoder_test.cc b/contrib/postgres_proxy/filters/network/test/postgres_decoder_test.cc index 638e6d9ae5c68..edf19b4933ab8 100644 --- a/contrib/postgres_proxy/filters/network/test/postgres_decoder_test.cc +++ b/contrib/postgres_proxy/filters/network/test/postgres_decoder_test.cc @@ -646,6 +646,8 @@ class FakeBuffer : public Buffer::Instance { MOCK_METHOD(void, prepend, (absl::string_view), (override)); MOCK_METHOD(void, prepend, (Instance&), (override)); MOCK_METHOD(void, copyOut, (size_t, uint64_t, void*), (const, override)); + MOCK_METHOD(uint64_t, copyOutToSlices, + (uint64_t size, Buffer::RawSlice* slices, uint64_t num_slice), (const, override)); MOCK_METHOD(void, drain, (uint64_t), (override)); MOCK_METHOD(Buffer::RawSliceVector, getRawSlices, (absl::optional), (const, override)); MOCK_METHOD(Buffer::RawSlice, frontSlice, (), (const, override)); diff --git a/envoy/buffer/buffer.h b/envoy/buffer/buffer.h index ba28c655a609a..504e3f890b699 100644 --- a/envoy/buffer/buffer.h +++ b/envoy/buffer/buffer.h @@ -201,6 +201,16 @@ class Instance { */ virtual void copyOut(size_t start, uint64_t size, void* data) const PURE; + /** + * Copy out a section of the buffer to dynamic array of slices. + * @param size supplies the size of the data that will be copied. + * @param slices supplies the output slices to fill. + * @param num_slice supplies the number of slices to fill. + * @return the number of bytes copied. + */ + virtual uint64_t copyOutToSlices(uint64_t size, Buffer::RawSlice* slices, + uint64_t num_slice) const PURE; + /** * Drain data from the buffer. * @param size supplies the length of data to drain. diff --git a/source/common/buffer/buffer_impl.cc b/source/common/buffer/buffer_impl.cc index 8c24ea4ff4ba9..a41493e7a51de 100644 --- a/source/common/buffer/buffer_impl.cc +++ b/source/common/buffer/buffer_impl.cc @@ -118,6 +118,43 @@ void OwnedImpl::copyOut(size_t start, uint64_t size, void* data) const { ASSERT(size == 0); } +uint64_t OwnedImpl::copyOutToSlices(uint64_t size, Buffer::RawSlice* dest_slices, + uint64_t num_slice) const { + uint64_t total_length_to_read = std::min(size, this->length()); + uint64_t num_bytes_read = 0; + uint64_t num_dest_slices_read = 0; + uint64_t num_src_slices_read = 0; + uint64_t dest_slice_offset = 0; + uint64_t src_slice_offset = 0; + while (num_dest_slices_read < num_slice && num_bytes_read < total_length_to_read) { + const Slice& src_slice = slices_[num_src_slices_read]; + const Buffer::RawSlice& dest_slice = dest_slices[num_dest_slices_read]; + uint64_t left_to_read = total_length_to_read - num_bytes_read; + uint64_t left_data_size_in_dst_slice = dest_slice.len_ - dest_slice_offset; + uint64_t left_data_size_in_src_slice = src_slice.dataSize() - src_slice_offset; + // The length to copy should be size of smallest in the source slice available size and + // the dest slice available size. + uint64_t length_to_copy = + std::min(left_data_size_in_src_slice, std::min(left_data_size_in_dst_slice, left_to_read)); + memcpy(static_cast(dest_slice.mem_) + dest_slice_offset, // NOLINT(safe-memcpy) + src_slice.data() + src_slice_offset, length_to_copy); + src_slice_offset = src_slice_offset + length_to_copy; + dest_slice_offset = dest_slice_offset + length_to_copy; + if (src_slice_offset == src_slice.dataSize()) { + num_src_slices_read++; + src_slice_offset = 0; + } + if (dest_slice_offset == dest_slice.len_) { + num_dest_slices_read++; + dest_slice_offset = 0; + } + ASSERT(src_slice_offset <= src_slice.dataSize()); + ASSERT(dest_slice_offset <= dest_slice.len_); + num_bytes_read += length_to_copy; + } + return num_bytes_read; +} + void OwnedImpl::drain(uint64_t size) { drainImpl(size); } void OwnedImpl::drainImpl(uint64_t size) { diff --git a/source/common/buffer/buffer_impl.h b/source/common/buffer/buffer_impl.h index bc8aa4ada4279..f74dfb54931f8 100644 --- a/source/common/buffer/buffer_impl.h +++ b/source/common/buffer/buffer_impl.h @@ -687,6 +687,8 @@ class OwnedImpl : public LibEventInstance { void prepend(absl::string_view data) override; void prepend(Instance& data) override; void copyOut(size_t start, uint64_t size, void* data) const override; + uint64_t copyOutToSlices(uint64_t size, Buffer::RawSlice* slices, + uint64_t num_slice) const override; void drain(uint64_t size) override; RawSliceVector getRawSlices(absl::optional max_slices = absl::nullopt) const override; RawSlice frontSlice() const override; diff --git a/source/common/network/BUILD b/source/common/network/BUILD index 1f352e7e59638..07e7c3783814b 100644 --- a/source/common/network/BUILD +++ b/source/common/network/BUILD @@ -199,6 +199,7 @@ envoy_cc_library( "//envoy/event:dispatcher_interface", "//envoy/network:io_handle_interface", "//source/common/api:os_sys_calls_lib", + "//source/common/buffer:buffer_lib", "//source/common/event:dispatcher_includes", "@envoy_api//envoy/extensions/network/socket_interface/v3:pkg_cc_proto", ], diff --git a/source/common/network/win32_socket_handle_impl.cc b/source/common/network/win32_socket_handle_impl.cc index c74f5144691b7..986562d339ba6 100644 --- a/source/common/network/win32_socket_handle_impl.cc +++ b/source/common/network/win32_socket_handle_impl.cc @@ -18,6 +18,10 @@ namespace Network { Api::IoCallUint64Result Win32SocketHandleImpl::readv(uint64_t max_length, Buffer::RawSlice* slices, uint64_t num_slice) { + if (peek_buffer_.length() != 0) { + return readvFromPeekBuffer(max_length, slices, num_slice); + } + auto result = IoSocketHandleImpl::readv(max_length, slices, num_slice); reEnableEventBasedOnIOResult(result, Event::FileReadyType::Read); return result; @@ -25,6 +29,10 @@ Api::IoCallUint64Result Win32SocketHandleImpl::readv(uint64_t max_length, Buffer Api::IoCallUint64Result Win32SocketHandleImpl::read(Buffer::Instance& buffer, absl::optional max_length_opt) { + if (peek_buffer_.length() != 0) { + return readFromPeekBuffer(buffer, max_length_opt.value_or(UINT64_MAX)); + } + auto result = IoSocketHandleImpl::read(buffer, max_length_opt); reEnableEventBasedOnIOResult(result, Event::FileReadyType::Read); return result; @@ -71,10 +79,41 @@ Api::IoCallUint64Result Win32SocketHandleImpl::recvmmsg(RawSliceArrays& slices, } Api::IoCallUint64Result Win32SocketHandleImpl::recv(void* buffer, size_t length, int flags) { + if (flags & MSG_PEEK) { + return emulatePeek(buffer, length); + } - Api::IoCallUint64Result result = IoSocketHandleImpl::recv(buffer, length, flags); - reEnableEventBasedOnIOResult(result, Event::FileReadyType::Read); - return result; + if (peek_buffer_.length() == 0) { + Api::IoCallUint64Result result = IoSocketHandleImpl::recv(buffer, length, flags); + reEnableEventBasedOnIOResult(result, Event::FileReadyType::Read); + return result; + } else { + return readFromPeekBuffer(buffer, length); + } +} + +Api::IoCallUint64Result Win32SocketHandleImpl::emulatePeek(void* buffer, size_t length) { + // If there's not enough data in the peek buffer, try reading more. + if (length > peek_buffer_.length()) { + // The caller is responsible for calling with the larger size + // in cases it needs to do so it can't rely on transparent event activation. + // So in this case we should activate read again unless the read blocked. + Api::IoCallUint64Result peek_result = drainToPeekBuffer(length); + + // Some error happened. + if (!peek_result.ok()) { + if (peek_result.wouldBlock() && file_event_) { + file_event_->registerEventIfEmulatedEdge(Event::FileReadyType::Read); + if (peek_buffer_.length() == 0) { + return peek_result; + } + } else { + return peek_result; + } + } + } + + return peekFromPeekBuffer(buffer, length); } void Win32SocketHandleImpl::reEnableEventBasedOnIOResult(const Api::IoCallUint64Result& result, @@ -84,5 +123,69 @@ void Win32SocketHandleImpl::reEnableEventBasedOnIOResult(const Api::IoCallUint64 } } +Api::IoCallUint64Result Win32SocketHandleImpl::drainToPeekBuffer(size_t length) { + size_t total_bytes_read = 0; + while (peek_buffer_.length() < length) { + Buffer::Reservation reservation = peek_buffer_.reserveForRead(); + uint64_t bytes_to_read = std::min( + static_cast(length - peek_buffer_.length()), reservation.length()); + Api::IoCallUint64Result result = + IoSocketHandleImpl::readv(bytes_to_read, reservation.slices(), reservation.numSlices()); + uint64_t bytes_to_commit = result.ok() ? result.return_value_ : 0; + reservation.commit(bytes_to_commit); + total_bytes_read += bytes_to_commit; + if (!result.ok() || bytes_to_commit == 0) { + return result; + } + } + return Api::IoCallUint64Result(total_bytes_read, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); +} + +Api::IoCallUint64Result Win32SocketHandleImpl::readFromPeekBuffer(void* buffer, size_t length) { + uint64_t copy_size = std::min(peek_buffer_.length(), static_cast(length)); + peek_buffer_.copyOut(0, copy_size, buffer); + peek_buffer_.drain(copy_size); + return Api::IoCallUint64Result(copy_size, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); +} + +Api::IoCallUint64Result Win32SocketHandleImpl::readvFromPeekBuffer(uint64_t max_length, + Buffer::RawSlice* slices, + uint64_t num_slice) { + uint64_t bytes_read = peek_buffer_.copyOutToSlices(max_length, slices, num_slice); + peek_buffer_.drain(bytes_read); + return Api::IoCallUint64Result(bytes_read, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); +} + +Api::IoCallUint64Result Win32SocketHandleImpl::readFromPeekBuffer(Buffer::Instance& buffer, + size_t length) { + auto lenght_to_move = std::min(peek_buffer_.length(), static_cast(length)); + buffer.move(peek_buffer_, lenght_to_move); + return Api::IoCallUint64Result(lenght_to_move, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); +} + +Api::IoCallUint64Result Win32SocketHandleImpl::peekFromPeekBuffer(void* buffer, size_t length) { + uint64_t copy_size = std::min(peek_buffer_.length(), static_cast(length)); + peek_buffer_.copyOut(0, copy_size, buffer); + return Api::IoCallUint64Result(copy_size, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); +} + +void Win32SocketHandleImpl::initializeFileEvent(Event::Dispatcher& dispatcher, + Event::FileReadyCb cb, + Event::FileTriggerType trigger, uint32_t events) { + IoSocketHandleImpl::initializeFileEvent(dispatcher, cb, trigger, events); + // Activate the file event directly when we have the data in the peek_buffer_. + if ((events & Event::FileReadyType::Read) && peek_buffer_.length() > 0) { + activateFileEvents(Event::FileReadyType::Read); + } +} + +void Win32SocketHandleImpl::enableFileEvents(uint32_t events) { + IoSocketHandleImpl::enableFileEvents(events); + // Activate the file event directly when we have the data in the peek_buffer_. + if ((events & Event::FileReadyType::Read) && peek_buffer_.length() > 0) { + activateFileEvents(Event::FileReadyType::Read); + } +} + } // namespace Network } // namespace Envoy diff --git a/source/common/network/win32_socket_handle_impl.h b/source/common/network/win32_socket_handle_impl.h index b9465f71db544..f8835e5aed544 100644 --- a/source/common/network/win32_socket_handle_impl.h +++ b/source/common/network/win32_socket_handle_impl.h @@ -6,6 +6,7 @@ #include "envoy/event/dispatcher.h" #include "envoy/network/io_handle.h" +#include "source/common/buffer/buffer_impl.h" #include "source/common/common/logger.h" #include "source/common/network/io_socket_error_impl.h" #include "source/common/network/io_socket_handle_impl.h" @@ -42,8 +43,37 @@ class Win32SocketHandleImpl : public IoSocketHandleImpl { RecvMsgOutput& output) override; Api::IoCallUint64Result recv(void* buffer, size_t length, int flags) override; + void initializeFileEvent(Event::Dispatcher& dispatcher, Event::FileReadyCb cb, + Event::FileTriggerType trigger, uint32_t events) override; + void enableFileEvents(uint32_t events) override; + private: void reEnableEventBasedOnIOResult(const Api::IoCallUint64Result& result, uint32_t event); + + // On Windows we use the MSG_PEEK on recv instead of peeking the socket + // we drain the socket to memory. Subsequent read calls need to read + // first from the class buffer and then go to the underlying socket. + + // Implement the peek logic of recv for readability purposes + Api::IoCallUint64Result emulatePeek(void* buffer, size_t length); + + /** + * Drain the socket into `peek_buffer_`. + * @param length is the desired length of data drained into the `peek_buffer_`. + * @return the actual length of data drained into the `peek_buffer_`. + */ + Api::IoCallUint64Result drainToPeekBuffer(size_t length); + + // Useful functions to read from the peek buffer based on + // the signatures of readv/read/recv OS socket functions. + Api::IoCallUint64Result readFromPeekBuffer(void* buffer, size_t length); + Api::IoCallUint64Result readFromPeekBuffer(Buffer::Instance& buffer, size_t length); + Api::IoCallUint64Result readvFromPeekBuffer(uint64_t max_length, Buffer::RawSlice* slices, + uint64_t num_slice); + Api::IoCallUint64Result peekFromPeekBuffer(void* buffer, size_t length); + + // For windows mimic MSG_PEEK + Buffer::OwnedImpl peek_buffer_; }; } // namespace Network } // namespace Envoy diff --git a/test/common/buffer/buffer_corpus/basic b/test/common/buffer/buffer_corpus/basic index 9fd31255e2e63..9f32b6d0bc31b 100644 --- a/test/common/buffer/buffer_corpus/basic +++ b/test/common/buffer/buffer_corpus/basic @@ -27,6 +27,9 @@ actions { length: 200 } } +actions { + copy_out_to_slices: 200 +} actions { drain: 98 } diff --git a/test/common/buffer/buffer_fuzz.cc b/test/common/buffer/buffer_fuzz.cc index 5fe4309c0bbb2..918c4e0d493ea 100644 --- a/test/common/buffer/buffer_fuzz.cc +++ b/test/common/buffer/buffer_fuzz.cc @@ -114,6 +114,21 @@ class StringBuffer : public Buffer::Instance { ::memcpy(data, this->start() + start, size); } + uint64_t copyOutToSlices(uint64_t length, Buffer::RawSlice* slices, + uint64_t num_slices) const override { + uint64_t size_copied = 0; + uint64_t num_slices_copied = 0; + while (size_copied < length && num_slices_copied < num_slices) { + auto copy_length = std::min((length - size_copied), slices[num_slices_copied].len_); + ::memcpy(slices[num_slices_copied].mem_, this->start(), copy_length); + size_copied += copy_length; + if (copy_length == slices[num_slices_copied].len_) { + num_slices_copied++; + } + } + return size_copied; + } + void drain(uint64_t size) override { FUZZ_ASSERT(size <= size_); start_ += size; @@ -318,6 +333,18 @@ uint32_t bufferAction(Context& ctxt, char insert_value, uint32_t max_alloc, Buff FUZZ_ASSERT(::memcmp(copy_buffer, data.data() + start, length) == 0); break; } + case test::common::buffer::Action::kCopyOutToSlices: { + const uint32_t length = + std::min(static_cast(target_buffer.length()), action.copy_out_to_slices()); + Buffer::OwnedImpl buffer; + auto reservation = buffer.reserveForRead(); + auto rc = target_buffer.copyOutToSlices(length, reservation.slices(), reservation.numSlices()); + reservation.commit(rc); + const std::string data = buffer.toString(); + const std::string target_data = target_buffer.toString(); + FUZZ_ASSERT(::memcmp(data.data(), target_data.data(), reservation.length()) == 0); + break; + } case test::common::buffer::Action::kDrain: { const uint32_t previous_length = target_buffer.length(); const uint32_t drain_length = diff --git a/test/common/buffer/buffer_fuzz.proto b/test/common/buffer/buffer_fuzz.proto index a4a18cc100c50..91a43f5d33e66 100644 --- a/test/common/buffer/buffer_fuzz.proto +++ b/test/common/buffer/buffer_fuzz.proto @@ -42,6 +42,7 @@ message Action { uint32 get_raw_slices = 14; Search search = 15; string starts_with = 16; + uint32 copy_out_to_slices = 17; } } diff --git a/test/common/buffer/owned_impl_test.cc b/test/common/buffer/owned_impl_test.cc index 7a4adc7d3058c..acbf598f2f425 100644 --- a/test/common/buffer/owned_impl_test.cc +++ b/test/common/buffer/owned_impl_test.cc @@ -1,4 +1,5 @@ #include +#include #include "envoy/api/io_error.h" @@ -1080,6 +1081,126 @@ void TestBufferMove(uint64_t buffer1_length, uint64_t buffer2_length, EXPECT_EQ(0, buffer2.length()); } +TEST_F(OwnedImplTest, CopyOutToSlicesTests) { + std::string data = "Hello, World!"; + Buffer::OwnedImpl buffer; + buffer.prepend(data); + + EXPECT_EQ(data.size(), buffer.length()); + EXPECT_EQ(data, buffer.toString()); + + { + Buffer::OwnedImpl buf; + auto reservation = buf.reserveSingleSlice(1024); + auto slice = reservation.slice(); + EXPECT_EQ(data.size(), buffer.copyOutToSlices(100, &slice, 1)); + reservation.commit(data.size()); + EXPECT_EQ(data, buffer.toString()); + } + + { + Buffer::OwnedImpl buf; + auto reservation = buf.reserveSingleSlice(5); + auto slice = reservation.slice(); + EXPECT_EQ(5, buffer.copyOutToSlices(100, &slice, 1)); + reservation.commit(5); + EXPECT_EQ("Hello", buf.toString()); + } + + { + Buffer::OwnedImpl buf; + auto reservation = buf.reserveForRead(); + EXPECT_EQ(5, buffer.copyOutToSlices(5, reservation.slices(), reservation.numSlices())); + reservation.commit(5); + EXPECT_EQ("Hello", buf.toString()); + } + + { + Buffer::OwnedImpl buf; + auto reservation = buf.reserveForRead(); + EXPECT_EQ(data.size(), + buffer.copyOutToSlices(100, reservation.slices(), reservation.numSlices())); + reservation.commit(data.size()); + EXPECT_EQ(data, buf.toString()); + } + // Test the destination buffer has smaller slice than the source buffer. + { + Buffer::OwnedImpl src_buf; + std::string data; + for (auto i = 0; i < (32 * 1024); i++) { + data.append(std::to_string(i % 10)); + } + // Build the source buffer to have a single 32KB slice. + src_buf.appendSliceForTest(data); + EXPECT_EQ(1, src_buf.getRawSlices().size()); + EXPECT_EQ(32 * 1024, src_buf.frontSlice().len_); + + Buffer::OwnedImpl dest_buf; + // The destination buffer are expected to have 8 Slices, each slice has 16KB buffer. + auto reservation = dest_buf.reserveForRead(); + EXPECT_EQ(8, reservation.numSlices()); + for (uint64_t i = 0; i < reservation.numSlices(); i++) { + EXPECT_EQ(16 * 1024, reservation.slices()[i].len_); + } + + // Copy single 32 KB slice's data to 8 * 16KB slices. + EXPECT_EQ(data.size(), + src_buf.copyOutToSlices(32 * 1024, reservation.slices(), reservation.numSlices())); + reservation.commit(data.size()); + EXPECT_EQ(data, dest_buf.toString()); + } + // Test the source buffer has smaller slice than the destination buffer. + { + Buffer::OwnedImpl src_buf; + // Build the source buffer to have 7 slices. + src_buf.appendSliceForTest("He", 2); + src_buf.appendSliceForTest("ll", 2); + src_buf.appendSliceForTest("o,", 2); + src_buf.appendSliceForTest(" W", 2); + src_buf.appendSliceForTest("or", 2); + src_buf.appendSliceForTest("ld", 2); + src_buf.appendSliceForTest("!", 1); + Buffer::OwnedImpl dest_buf; + // The destination buffer are expected to have 8 Slices, each slice has 16KB buffer. + auto reservation = dest_buf.reserveForRead(); + EXPECT_EQ(8, reservation.numSlices()); + for (uint64_t i = 0; i < reservation.numSlices(); i++) { + EXPECT_EQ(16 * 1024, reservation.slices()[i].len_); + } + + // Copy data from src 7 slices into the first 16K slice of dest. + EXPECT_EQ(data.size(), + src_buf.copyOutToSlices(100, reservation.slices(), reservation.numSlices())); + reservation.commit(data.size()); + EXPECT_EQ(data, dest_buf.toString()); + } + { + Buffer::OwnedImpl src_buffer; + // Create a slice with a small amount of data. + const uint32_t small_data_size = 10; + std::string small_data = std::string(small_data_size, 'a'); + src_buffer.prepend(small_data); + + // Add another slice with a large amount of data. + const uint32_t large_data_size = 16384; + std::string large_data = std::string(large_data_size, 'b'); + BufferFragmentImpl frag(large_data.data(), large_data.size(), nullptr); + src_buffer.addBufferFragment(frag); + EXPECT_EQ(small_data_size + large_data_size, src_buffer.length()); + + // Copy-out from the buffer. + Buffer::OwnedImpl dest_buf; + auto reservation = dest_buf.reserveForRead(); + EXPECT_EQ(small_data_size + large_data_size, + src_buffer.copyOutToSlices(small_data_size + large_data_size, reservation.slices(), + reservation.numSlices())); + reservation.commit(small_data_size + large_data_size); + EXPECT_EQ(absl::StrCat(small_data, large_data), dest_buf.toString()); + + src_buffer.drain(small_data_size + large_data_size); + } +} + // Slice size large enough to prevent slice content from being coalesced into an existing slice constexpr uint64_t kLargeSliceSize = 2048; diff --git a/test/common/network/BUILD b/test/common/network/BUILD index f2390df3ceced..0afaf9f64096a 100644 --- a/test/common/network/BUILD +++ b/test/common/network/BUILD @@ -408,6 +408,19 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "win32_socket_handle_impl_test", + srcs = ["win32_socket_handle_impl_test.cc"], + deps = [ + "//source/common/buffer:buffer_lib", + "//source/common/common:utility_lib", + "//source/common/network:address_lib", + "//test/mocks/api:api_mocks", + "//test/mocks/event:event_mocks", + "//test/test_common:threadsafe_singleton_injector_lib", + ], +) + envoy_cc_test( name = "io_socket_handle_impl_integration_test", srcs = ["io_socket_handle_impl_integration_test.cc"], diff --git a/test/common/network/win32_socket_handle_impl_test.cc b/test/common/network/win32_socket_handle_impl_test.cc new file mode 100644 index 0000000000000..c06880d9f199b --- /dev/null +++ b/test/common/network/win32_socket_handle_impl_test.cc @@ -0,0 +1,190 @@ +#include "source/common/common/utility.h" +#include "source/common/network/address_impl.h" +#include "source/common/network/io_socket_error_impl.h" +#include "source/common/network/io_socket_handle_impl.h" +#include "source/common/network/listen_socket_impl.h" + +#include "test/mocks/api/mocks.h" +#include "test/mocks/event/mocks.h" +#include "test/test_common/environment.h" +#include "test/test_common/network_utility.h" +#include "test/test_common/threadsafe_singleton_injector.h" + +#include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::_; +using testing::Invoke; +using testing::NiceMock; +using testing::Return; + +namespace Envoy { +namespace Network { + +class Win32SocketHandleImplTest : public testing::Test { +public: + Win32SocketHandleImplTest() : io_handle_(42) { + dispatcher_ = std::make_unique>(); + file_event_ = new NiceMock; + EXPECT_CALL(*dispatcher_, createFileEvent_(42, _, _, _)).WillOnce(Return(file_event_)); + io_handle_.setBlocking(false); + io_handle_.initializeFileEvent( + *dispatcher_, [](uint32_t) -> void {}, Event::PlatformDefaultTriggerType, + Event::FileReadyType::Read | Event::FileReadyType::Closed); + } + +protected: + std::unique_ptr> dispatcher_; + NiceMock* file_event_; + Network::Win32SocketHandleImpl io_handle_; +}; + +TEST_F(Win32SocketHandleImplTest, ReadvWithNoBufferShouldReadFromTheWire) { + + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(1) + .WillRepeatedly(Return(Api::SysCallSizeResult{10, 0})); + + Buffer::OwnedImpl read_buffer; + Buffer::Reservation reservation = read_buffer.reserveForRead(); + auto rc = io_handle_.readv(reservation.length(), reservation.slices(), reservation.numSlices()); + EXPECT_EQ(rc.return_value_, 10); +} + +TEST_F(Win32SocketHandleImplTest, ReadvShouldReenableEventsOnBlock) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(1) + .WillRepeatedly(Return(Api::SysCallSizeResult{-1, SOCKET_ERROR_AGAIN})); + + EXPECT_CALL(*file_event_, registerEventIfEmulatedEdge(_)); + Buffer::OwnedImpl read_buffer; + Buffer::Reservation reservation = read_buffer.reserveForRead(); + auto rc = io_handle_.readv(reservation.length(), reservation.slices(), reservation.numSlices()); + EXPECT_EQ(rc.return_value_, 0); + EXPECT_EQ(rc.err_->getErrorCode(), Api::IoError::IoErrorCode::Again); +} + +TEST_F(Win32SocketHandleImplTest, ReadvWithBufferShouldReadFromBuffer) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + constexpr int data_length = 10; + std::string data(data_length, '*'); + EXPECT_CALL(os_sys_calls, readv(_, _, _)).WillOnce(Invoke([&](os_fd_t, const iovec* iov, int) { + memcpy(iov->iov_base, data.data(), data_length); // NOLINT(safe-memcpy) + return Api::SysCallSizeResult{data_length, 0}; + })); + + absl::FixedArray buf(data_length); + auto rc = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc.return_value_, data_length); + EXPECT_EQ(data, std::string(buf.data(), buf.size())); + Buffer::OwnedImpl read_buffer; + Buffer::Reservation reservation = read_buffer.reserveForRead(); + rc = io_handle_.readv(reservation.length(), reservation.slices(), reservation.numSlices()); + EXPECT_EQ(rc.return_value_, 10); + reservation.commit(rc.return_value_); + EXPECT_EQ(data, read_buffer.toString()); +} + +TEST_F(Win32SocketHandleImplTest, RecvWithoutPeekShouldReadFromWire) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) + .Times(1) + .WillRepeatedly(Return(Api::SysCallSizeResult{10, 0})); + + absl::FixedArray buf(10); + auto rc = io_handle_.recv(buf.data(), buf.size(), 0); + EXPECT_EQ(rc.return_value_, 10); +} + +TEST_F(Win32SocketHandleImplTest, RecvWithPeekMultipleTimes) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .WillOnce(Invoke([&](os_fd_t, const iovec* iov, int num_iovs) { + size_t size_to_read = 0; + for (auto i = 0; i < num_iovs; i++) { + size_to_read += iov[i].iov_len; + } + EXPECT_EQ(10, size_to_read); + return Api::SysCallSizeResult{5, 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{-1, SOCKET_ERROR_AGAIN})); + + EXPECT_CALL(*file_event_, registerEventIfEmulatedEdge(_)); + absl::FixedArray buf(10); + auto rc = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc.return_value_, 5); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .WillOnce(Invoke([&](os_fd_t, const iovec* iov, int num_iovs) { + size_t size_to_read = 0; + for (auto i = 0; i < num_iovs; i++) { + size_to_read += iov[i].iov_len; + } + EXPECT_EQ(5, size_to_read); + return Api::SysCallSizeResult{5, 0}; + })); + auto rc2 = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc2.return_value_, 10); +} + +TEST_F(Win32SocketHandleImplTest, RecvWithPeekReactivatesReadOnBlock) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(1) + .WillOnce(Return(Api::SysCallSizeResult{-1, SOCKET_ERROR_AGAIN})); + + EXPECT_CALL(*file_event_, registerEventIfEmulatedEdge(_)); + absl::FixedArray buf(10); + auto rc = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc.err_->getErrorCode(), Api::IoError::IoErrorCode::Again); +} + +TEST_F(Win32SocketHandleImplTest, RecvWithPeekFlagReturnsFinalError) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + constexpr int data_length = 10; + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(2) + .WillOnce(Invoke([&](os_fd_t, const iovec*, int) { + return Api::SysCallSizeResult{data_length / 2, 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{-1, SOCKET_ERROR_CONNRESET})); + + absl::FixedArray buf(data_length); + auto rc = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc.err_->getErrorCode(), Api::IoError::IoErrorCode::ConnectionReset); +} + +TEST_F(Win32SocketHandleImplTest, ReadvWithPeekShouldReadFromBuffer) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + constexpr int data_length = 10; + std::string data(data_length, '*'); + EXPECT_CALL(os_sys_calls, readv(_, _, _)).WillOnce(Invoke([&](os_fd_t, const iovec* iov, int) { + memcpy(iov->iov_base, data.data(), data_length); // NOLINT(safe-memcpy) + return Api::SysCallSizeResult{data_length, 0}; + })); + + absl::FixedArray buf(data_length); + auto rc = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc.return_value_, data_length); + EXPECT_EQ(data, std::string(buf.data(), buf.size())); + // Second call should not make a system call, it should + // read from memory. + rc = io_handle_.recv(buf.data(), buf.size(), MSG_PEEK); + EXPECT_EQ(rc.return_value_, data_length); + EXPECT_EQ(data, std::string(buf.data(), buf.size())); +} + +} // namespace Network +} // namespace Envoy diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc index 7f99111dc1cbc..ecece6746e42d 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc @@ -316,25 +316,31 @@ TEST_P(ProxyProtocolTest, ErrorRecv_2) { Api::MockOsSysCalls os_sys_calls; TestThreadsafeSingletonInjector os_calls(&os_sys_calls); - // TODO(davinci26): Mocking should not be used to provide real system calls. +// TODO(davinci26): Mocking should not be used to provide real system calls. +#ifdef WIN32 + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Return(Api::SysCallSizeResult{-1, 0})); +#else + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.readv(fd, iov, iovcnt); + })); + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Return(Api::SysCallSizeResult{-1, 0})); +#endif EXPECT_CALL(os_sys_calls, connect(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t sockfd, const sockaddr* addr, socklen_t addrlen) { return os_sys_calls_actual_.connect(sockfd, addr, addrlen); })); - EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) - .Times(AnyNumber()) - .WillOnce(Return(Api::SysCallSizeResult{-1, 0})); EXPECT_CALL(os_sys_calls, writev(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { return os_sys_calls_actual_.writev(fd, iov, iovcnt); })); - EXPECT_CALL(os_sys_calls, readv(_, _, _)) - .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { - return os_sys_calls_actual_.readv(fd, iov, iovcnt); - })); EXPECT_CALL(os_sys_calls, getsockopt_(_, _, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke( @@ -387,9 +393,20 @@ TEST_P(ProxyProtocolTest, ErrorRecv_1) { TestThreadsafeSingletonInjector os_calls(&os_sys_calls); // TODO(davinci26): Mocking should not be used to provide real system calls. +#ifdef WIN32 + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Return(Api::SysCallSizeResult{-1, 0})); +#else + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.readv(fd, iov, iovcnt); + })); EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Return(Api::SysCallSizeResult{-1, 0})); +#endif EXPECT_CALL(os_sys_calls, connect(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t sockfd, const sockaddr* addr, socklen_t addrlen) { @@ -400,11 +417,6 @@ TEST_P(ProxyProtocolTest, ErrorRecv_1) { .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { return os_sys_calls_actual_.writev(fd, iov, iovcnt); })); - EXPECT_CALL(os_sys_calls, readv(_, _, _)) - .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { - return os_sys_calls_actual_.readv(fd, iov, iovcnt); - })); EXPECT_CALL(os_sys_calls, getsockopt_(_, _, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke( @@ -813,6 +825,15 @@ TEST_P(ProxyProtocolTest, V2Fragmented4Error) { TestThreadsafeSingletonInjector os_calls(&os_sys_calls); // TODO(davinci26): Mocking should not be used to provide real system calls. +#ifdef WIN32 + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillOnce(Invoke([&](os_fd_t fd, const iovec* iov, int num_iov) { + const Api::SysCallSizeResult x = os_sys_calls_actual_.readv(fd, iov, num_iov); + return x; + })) + .WillRepeatedly(Return(Api::SysCallSizeResult{-1, 0})); +#else EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t fd, void* buf, size_t len, int flags) { @@ -821,6 +842,13 @@ TEST_P(ProxyProtocolTest, V2Fragmented4Error) { EXPECT_CALL(os_sys_calls, recv(_, _, 1, _)) .Times(AnyNumber()) .WillOnce(Return(Api::SysCallSizeResult{-1, 0})); + + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.readv(fd, iov, iovcnt); + })); +#endif EXPECT_CALL(os_sys_calls, connect(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t sockfd, const sockaddr* addr, socklen_t addrlen) { @@ -831,11 +859,6 @@ TEST_P(ProxyProtocolTest, V2Fragmented4Error) { .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { return os_sys_calls_actual_.writev(fd, iov, iovcnt); })); - EXPECT_CALL(os_sys_calls, readv(_, _, _)) - .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { - return os_sys_calls_actual_.readv(fd, iov, iovcnt); - })); EXPECT_CALL(os_sys_calls, getsockopt_(_, _, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke( @@ -889,6 +912,18 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { TestThreadsafeSingletonInjector os_calls(&os_sys_calls); // TODO(davinci26): Mocking should not be used to provide real system calls. +#ifdef WIN32 + bool partial_write = false; + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([&](os_fd_t fd, const iovec* iov, int num_iov) { + if (partial_write) { + return Api::SysCallSizeResult{-1, 0}; + } + const Api::SysCallSizeResult x = os_sys_calls_actual_.readv(fd, iov, num_iov); + return x; + })); +#else EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t fd, void* buf, size_t len, int flags) { @@ -897,6 +932,12 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { EXPECT_CALL(os_sys_calls, recv(_, _, 4, _)) .Times(AnyNumber()) .WillOnce(Return(Api::SysCallSizeResult{-1, 0})); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.readv(fd, iov, iovcnt); + })); +#endif EXPECT_CALL(os_sys_calls, connect(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t sockfd, const sockaddr* addr, socklen_t addrlen) { @@ -907,11 +948,6 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { return os_sys_calls_actual_.writev(fd, iov, iovcnt); })); - EXPECT_CALL(os_sys_calls, readv(_, _, _)) - .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { - return os_sys_calls_actual_.readv(fd, iov, iovcnt); - })); EXPECT_CALL(os_sys_calls, getsockopt_(_, _, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke( @@ -950,6 +986,9 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { connect(false); write(buffer, 10); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); +#ifdef WIN32 + partial_write = true; +#endif write(buffer + 10, 10); expectProxyProtoError();