diff --git a/docs/root/configuration/listeners/listener_filters/tls_inspector.rst b/docs/root/configuration/listeners/listener_filters/tls_inspector.rst index e863d895946c1..65bb3a65e79b5 100644 --- a/docs/root/configuration/listeners/listener_filters/tls_inspector.rst +++ b/docs/root/configuration/listeners/listener_filters/tls_inspector.rst @@ -49,9 +49,7 @@ This filter has a statistics tree rooted at *tls_inspector* with the following s :header: Name, Type, Description :widths: 1, 1, 2 - connection_closed, Counter, Total connections closed client_hello_too_large, Counter, Total unreasonably large Client Hello received - read_error, Counter, Total read errors tls_found, Counter, Total number of times TLS was found tls_not_found, Counter, Total number of times TLS was not found alpn_found, Counter, Total number of times `Application-Layer Protocol Negotiation `_ was successful diff --git a/docs/root/configuration/listeners/stats.rst b/docs/root/configuration/listeners/stats.rst index 4b3ed840237f2..68dd3b8039ee7 100644 --- a/docs/root/configuration/listeners/stats.rst +++ b/docs/root/configuration/listeners/stats.rst @@ -26,6 +26,8 @@ with the following statistics: downstream_pre_cx_active, Gauge, Sockets currently undergoing listener filter processing global_cx_overflow, Counter, Total connections rejected due to enforcement of the global connection limit no_filter_chain_match, Counter, Total connections that didn't match any filter chain + downstream_listener_filter_remote_close, Counter, Total connections closed by remote when peek data for listener filters + downstream_listener_filter_error, Counter, Total numbers of error when peek data for listener filters .. _config_listener_stats_tls: diff --git a/docs/root/version_history/current.rst b/docs/root/version_history/current.rst index 83f235563ead2..81f47d6e7aebb 100644 --- a/docs/root/version_history/current.rst +++ b/docs/root/version_history/current.rst @@ -7,6 +7,7 @@ Incompatible Behavior Changes * sip-proxy: change API by replacing ``own_domain`` with :ref:`local_services `. * tls: set TLS v1.2 as the default minimal version for servers. Users can still explicitly opt-in to 1.0 and 1.1 using :ref:`tls_minimum_protocol_version `. +* tls-inspector: the listener filter tls inspector's stats ``connection_closed`` and ``read_error`` are removed. The new stats are introduced for listener, ``downstream_peek_remote_close`` and ``read_error`` :ref:`listener stats `. Minor Behavior Changes ---------------------- diff --git a/envoy/buffer/buffer.h b/envoy/buffer/buffer.h index ce07170b75bb1..b5233746e3648 100644 --- a/envoy/buffer/buffer.h +++ b/envoy/buffer/buffer.h @@ -34,6 +34,17 @@ struct RawSlice { bool operator!=(const RawSlice& rhs) const { return !(*this == rhs); } }; +/** + * A const raw memory data slice including the location and length. + */ +struct ConstRawSlice { + const void* mem_ = nullptr; + size_t len_ = 0; + + bool operator==(const RawSlice& rhs) const { return mem_ == rhs.mem_ && len_ == rhs.len_; } + bool operator!=(const RawSlice& rhs) const { return !(*this == rhs); } +}; + using RawSliceVector = absl::InlinedVector; /** diff --git a/envoy/network/BUILD b/envoy/network/BUILD index d3f63009b9789..dee05204c2796 100644 --- a/envoy/network/BUILD +++ b/envoy/network/BUILD @@ -87,6 +87,7 @@ envoy_cc_library( hdrs = ["filter.h"], deps = [ ":listen_socket_interface", + ":listener_filter_buffer_interface", ":transport_socket_interface", "//envoy/buffer:buffer_interface", "//envoy/stream_info:stream_info_interface", @@ -148,6 +149,14 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "listener_filter_buffer_interface", + hdrs = ["listener_filter_buffer.h"], + deps = [ + "//envoy/buffer:buffer_interface", + ], +) + envoy_cc_library( name = "transport_socket_interface", hdrs = ["transport_socket.h"], diff --git a/envoy/network/filter.h b/envoy/network/filter.h index a41a76cad45ab..5846324e673d7 100644 --- a/envoy/network/filter.h +++ b/envoy/network/filter.h @@ -4,6 +4,7 @@ #include "envoy/buffer/buffer.h" #include "envoy/network/listen_socket.h" +#include "envoy/network/listener_filter_buffer.h" #include "envoy/network/transport_socket.h" #include "envoy/stream_info/stream_info.h" #include "envoy/upstream/host_description.h" @@ -330,6 +331,20 @@ class ListenerFilter { * @return status used by the filter manager to manage further filter iteration. */ virtual FilterStatus onAccept(ListenerFilterCallbacks& cb) PURE; + + /** + * Called when data read from the connection. If the filter chain doesn't get + * enough data, the filter chain can be stopped, then waiting for more data. + * @param buffer the buffer of data. + * @return status used by the filter manager to manage further filter iteration. + */ + virtual FilterStatus onData(Network::ListenerFilterBuffer& buffer) PURE; + + /** + * Return the size of data the filter want to inspect from the connection. + * @return the size of data inspect from the connection. 0 means filter needn't any data. + */ + virtual size_t maxReadBytes() const PURE; }; using ListenerFilterPtr = std::unique_ptr; diff --git a/envoy/network/listener_filter_buffer.h b/envoy/network/listener_filter_buffer.h new file mode 100644 index 0000000000000..23cbccd4994f9 --- /dev/null +++ b/envoy/network/listener_filter_buffer.h @@ -0,0 +1,33 @@ +#pragma once + +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/common/pure.h" + +namespace Envoy { +namespace Network { + +/** + * Interface for ListenerFilterBuffer + */ +class ListenerFilterBuffer { +public: + virtual ~ListenerFilterBuffer() = default; + + /** + * Return a single const raw slice to the buffer of the data. + * @return a Buffer::ConstRawSlice pointed to raw buffer. + */ + virtual const Buffer::ConstRawSlice rawSlice() const PURE; + + /** + * Drain the data from the beginning of the buffer. + * @param length the length of data to drain. + * @return a bool indicate the drain is successful or not. + */ + virtual bool drain(uint64_t length) PURE; +}; + +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/BUILD b/source/common/network/BUILD index a7972e857c023..fb8114973294a 100644 --- a/source/common/network/BUILD +++ b/source/common/network/BUILD @@ -242,6 +242,17 @@ envoy_cc_library( ], ) +envoy_cc_library( + name = "listener_filter_buffer_lib", + srcs = ["listener_filter_buffer_impl.cc"], + hdrs = ["listener_filter_buffer_impl.h"], + deps = [ + "//envoy/network:io_handle_interface", + "//envoy/network:listener_filter_buffer_interface", + "//source/common/buffer:buffer_lib", + ], +) + envoy_cc_library( name = "listener_lib", srcs = [ diff --git a/source/common/network/listener_filter_buffer_impl.cc b/source/common/network/listener_filter_buffer_impl.cc new file mode 100644 index 0000000000000..42c33c3763872 --- /dev/null +++ b/source/common/network/listener_filter_buffer_impl.cc @@ -0,0 +1,107 @@ +#include "source/common/network/listener_filter_buffer_impl.h" + +#include + +namespace Envoy { +namespace Network { + +ListenerFilterBufferImpl::ListenerFilterBufferImpl(IoHandle& io_handle, + Event::Dispatcher& dispatcher, + ListenerFilterBufferOnCloseCb close_cb, + ListenerFilterBufferOnDataCb on_data_cb, + uint64_t buffer_size) + : io_handle_(io_handle), dispatcher_(dispatcher), on_close_cb_(close_cb), + on_data_cb_(on_data_cb), buffer_(std::make_unique(buffer_size)), + base_(buffer_.get()), buffer_size_(buffer_size) { + // If the buffer_size not greater than 0, it means that doesn't expect any data. + ASSERT(buffer_size > 0); + + io_handle_.initializeFileEvent( + dispatcher_, [this](uint32_t events) { onFileEvent(events); }, + Event::PlatformDefaultTriggerType, Event::FileReadyType::Read); +} + +const Buffer::ConstRawSlice ListenerFilterBufferImpl::rawSlice() const { + Buffer::ConstRawSlice slice; + slice.mem_ = base_; + slice.len_ = data_size_; + return slice; +} + +bool ListenerFilterBufferImpl::drain(uint64_t length) { + if (length == 0) { + return true; + } + + ASSERT(length <= data_size_); + + uint64_t read_size = 0; + while (read_size < length) { + auto result = io_handle_.recv(base_, length - read_size, 0); + ENVOY_LOG(trace, "recv returned: {}", result.return_value_); + + if (!result.ok()) { + // `IoErrorCode::Again` isn't processed here, since + // the data already in the socket buffer. + return false; + } + read_size += result.return_value_; + } + base_ += length; + data_size_ -= length; + return true; +} + +PeekState ListenerFilterBufferImpl::peekFromSocket() { + // Reset buffer base in case of draining changed base. + auto old_base = base_; + base_ = buffer_.get(); + const auto result = io_handle_.recv(base_, buffer_size_, MSG_PEEK); + ENVOY_LOG(trace, "recv returned: {}", result.return_value_); + + if (!result.ok()) { + if (result.err_->getErrorCode() == Api::IoError::IoErrorCode::Again) { + ENVOY_LOG(trace, "recv return try again"); + base_ = old_base; + return PeekState::Again; + } + ENVOY_LOG(debug, "recv failed: {}: {}", static_cast(result.err_->getErrorCode()), + result.err_->getErrorDetails()); + return PeekState::Error; + } + // Remote closed + if (result.return_value_ == 0) { + ENVOY_LOG(debug, "recv failed: remote closed"); + return PeekState::RemoteClose; + } + data_size_ = result.return_value_; + ASSERT(data_size_ <= buffer_size_); + + return PeekState::Done; +} + +void ListenerFilterBufferImpl::resetCapacity(uint64_t size) { + buffer_ = std::make_unique(size); + base_ = buffer_.get(); + buffer_size_ = size; + data_size_ = 0; +} + +void ListenerFilterBufferImpl::activateFileEvent(uint32_t events) { onFileEvent(events); } + +void ListenerFilterBufferImpl::onFileEvent(uint32_t events) { + ENVOY_LOG(trace, "onFileEvent: {}", events); + + auto state = peekFromSocket(); + if (state == PeekState::Done) { + on_data_cb_(*this); + } else if (state == PeekState::Error) { + on_close_cb_(true); + } else if (state == PeekState::RemoteClose) { + on_close_cb_(false); + } + // Did nothing for `Api::IoError::IoErrorCode::Again` +} + +} // namespace Network +} // namespace Envoy diff --git a/source/common/network/listener_filter_buffer_impl.h b/source/common/network/listener_filter_buffer_impl.h new file mode 100644 index 0000000000000..fb02705c077bd --- /dev/null +++ b/source/common/network/listener_filter_buffer_impl.h @@ -0,0 +1,72 @@ +#pragma once + +#include +#include + +#include "envoy/buffer/buffer.h" +#include "envoy/network/io_handle.h" +#include "envoy/network/listener_filter_buffer.h" + +#include "source/common/buffer/buffer_impl.h" + +namespace Envoy { +namespace Network { + +class ListenerFilterBufferImpl; +using ListenerFilterBufferOnCloseCb = std::function; +using ListenerFilterBufferOnDataCb = std::function; + +enum class PeekState { + // Peek data status successful. + Done, + // Need to try again. + Again, + // Error to peek data. + Error, + // Connection closed by remote. + RemoteClose, +}; + +class ListenerFilterBufferImpl : public ListenerFilterBuffer, Logger::Loggable { +public: + ListenerFilterBufferImpl(IoHandle& io_handle, Event::Dispatcher& dispatcher, + ListenerFilterBufferOnCloseCb close_cb, + ListenerFilterBufferOnDataCb on_data_cb, uint64_t buffer_size); + + // ListenerFilterBuffer + const Buffer::ConstRawSlice rawSlice() const override; + bool drain(uint64_t length) override; + + /** + * Trigger the data peek from the socket. + */ + PeekState peekFromSocket(); + + void reset() { io_handle_.resetFileEvents(); } + + void activateFileEvent(uint32_t events); + uint64_t capacity() const { return buffer_size_; } + void resetCapacity(uint64_t size); + +private: + void onFileEvent(uint32_t events); + + IoHandle& io_handle_; + Event::Dispatcher& dispatcher_; + ListenerFilterBufferOnCloseCb on_close_cb_; + ListenerFilterBufferOnDataCb on_data_cb_; + + // The buffer for the data peeked from the socket. + std::unique_ptr buffer_; + // The start of buffer. + uint8_t* base_; + // The size of buffer; + uint64_t buffer_size_; + // The size of valid data. + uint64_t data_size_{0}; +}; + +using ListenerFilterBufferImplPtr = std::unique_ptr; + +} // namespace Network +} // namespace Envoy diff --git a/source/extensions/filters/listener/http_inspector/http_inspector.cc b/source/extensions/filters/listener/http_inspector/http_inspector.cc index f2c864010abc5..71f3f2883d651 100644 --- a/source/extensions/filters/listener/http_inspector/http_inspector.cc +++ b/source/extensions/filters/listener/http_inspector/http_inspector.cc @@ -22,7 +22,6 @@ Config::Config(Stats::Scope& scope) : stats_{ALL_HTTP_INSPECTOR_STATS(POOL_COUNTER_PREFIX(scope, "http_inspector."))} {} const absl::string_view Filter::HTTP2_CONNECTION_PREFACE = "PRI * HTTP/2.0\r\n\r\nSM\r\n\r\n"; -thread_local uint8_t Filter::buf_[Config::MAX_INSPECT_SIZE]; Filter::Filter(const ConfigSharedPtr config) : config_(config) { http_parser_init(&parser_, HTTP_REQUEST); @@ -32,88 +31,41 @@ http_parser_settings Filter::settings_{ nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, }; -Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) { - ENVOY_LOG(debug, "http inspector: new connection accepted"); - - const Network::ConnectionSocket& socket = cb.socket(); - - const absl::string_view transport_protocol = socket.detectedTransportProtocol(); - if (!transport_protocol.empty() && transport_protocol != "raw_buffer") { - ENVOY_LOG(trace, "http inspector: cannot inspect http protocol with transport socket {}", - transport_protocol); - return Network::FilterStatus::Continue; - } - - cb_ = &cb; - const ParseState parse_state = onRead(); +Network::FilterStatus Filter::onData(Network::ListenerFilterBuffer& buffer) { + auto raw_slice = buffer.rawSlice(); + const char* buf = static_cast(raw_slice.mem_); + const auto parse_state = parseHttpHeader(absl::string_view(buf, raw_slice.len_)); switch (parse_state) { case ParseState::Error: + done(false); // As per discussion in https://github.com/envoyproxy/envoy/issues/7864 // we don't add new enum in FilterStatus so we have to signal the caller // the new condition. - cb.socket().close(); + cb_->socket().close(); return Network::FilterStatus::StopIteration; case ParseState::Done: + done(true); return Network::FilterStatus::Continue; case ParseState::Continue: - // do nothing but create the event - cb.socket().ioHandle().initializeFileEvent( - cb.dispatcher(), - [this](uint32_t events) { - ENVOY_LOG(trace, "http inspector event: {}", events); - - const ParseState parse_state = onRead(); - switch (parse_state) { - case ParseState::Error: - cb_->socket().ioHandle().resetFileEvents(); - cb_->continueFilterChain(false); - break; - case ParseState::Done: - cb_->socket().ioHandle().resetFileEvents(); - // Do not skip following listener filters. - cb_->continueFilterChain(true); - break; - case ParseState::Continue: - // do nothing but wait for the next event - break; - } - }, - Event::PlatformDefaultTriggerType, Event::FileReadyType::Read); return Network::FilterStatus::StopIteration; } PANIC_DUE_TO_CORRUPT_ENUM } -ParseState Filter::onRead() { - auto result = cb_->socket().ioHandle().recv(buf_, Config::MAX_INSPECT_SIZE, MSG_PEEK); - ENVOY_LOG(trace, "http inspector: recv: {}", result.return_value_); - if (!result.ok()) { - if (result.err_->getErrorCode() == Api::IoError::IoErrorCode::Again) { - return ParseState::Continue; - } - config_->stats().read_error_.inc(); - return ParseState::Error; - } +Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) { + ENVOY_LOG(debug, "http inspector: new connection accepted"); - // Remote closed - if (result.return_value_ == 0) { - return ParseState::Error; - } + const Network::ConnectionSocket& socket = cb.socket(); - const auto parse_state = - parseHttpHeader(absl::string_view(reinterpret_cast(buf_), result.return_value_)); - switch (parse_state) { - case ParseState::Continue: - // do nothing but wait for the next event - return ParseState::Continue; - case ParseState::Error: - done(false); - return ParseState::Done; - case ParseState::Done: - done(true); - return ParseState::Done; + const absl::string_view transport_protocol = socket.detectedTransportProtocol(); + if (!transport_protocol.empty() && transport_protocol != "raw_buffer") { + ENVOY_LOG(trace, "http inspector: cannot inspect http protocol with transport socket {}", + transport_protocol); + return Network::FilterStatus::Continue; } - PANIC_DUE_TO_CORRUPT_ENUM; + + cb_ = &cb; + return Network::FilterStatus::StopIteration; } ParseState Filter::parseHttpHeader(absl::string_view data) { diff --git a/source/extensions/filters/listener/http_inspector/http_inspector.h b/source/extensions/filters/listener/http_inspector/http_inspector.h index b43422f9e66ed..476c4e3b58917 100644 --- a/source/extensions/filters/listener/http_inspector/http_inspector.h +++ b/source/extensions/filters/listener/http_inspector/http_inspector.h @@ -69,11 +69,13 @@ class Filter : public Network::ListenerFilter, Logger::Loggablestats_.downstream_cx_proxy_proto_error_.inc(); - cb_->continueFilterChain(false); + cb_->socket().ioHandle().close(); + return Network::FilterStatus::StopIteration; + } else if (read_state == ReadOrParseState::TryAgainLater) { + return Network::FilterStatus::StopIteration; } + return Network::FilterStatus::Continue; } -ReadOrParseState Filter::onReadWorker() { +ReadOrParseState Filter::parseBuffer(Network::ListenerFilterBuffer& buffer) { Network::ConnectionSocket& socket = cb_->socket(); // We return if a) we do not yet have the header, b) we have the header but not yet all - // the extension data, or c) a socket error occurred when reading the header or the extension - // data. In cases a) and b) we'll be called again when the socket is ready to read and pick up - // where we left off. + // the extension data. if (!proxy_protocol_header_.has_value()) { - const ReadOrParseState read_header_state = readProxyHeader(socket.ioHandle()); + const ReadOrParseState read_header_state = readProxyHeader(buffer); if (read_header_state != ReadOrParseState::Done) { return read_header_state; } } + + // After parse the header, the extensions size is discovered. Then extend the buffer + // size to receive the extensions. + if (proxy_protocol_header_.value().wholeHeaderLength() > max_proxy_protocol_len_) { + max_proxy_protocol_len_ = proxy_protocol_header_.value().wholeHeaderLength(); + // The expected header size is changed, waiting for more data. + return ReadOrParseState::TryAgainLater; + } + if (proxy_protocol_header_.has_value()) { - const ReadOrParseState read_ext_state = readExtensions(socket.ioHandle()); + const ReadOrParseState read_ext_state = readExtensions(buffer); if (read_ext_state != ReadOrParseState::Done) { return read_ext_state; } @@ -135,13 +139,13 @@ ReadOrParseState Filter::onReadWorker() { proxy_protocol_header_.value().remote_address_); } - // Release the file event so that we do not interfere with the connection read events. - socket.ioHandle().resetFileEvents(); - cb_->continueFilterChain(true); + if (!buffer.drain(proxy_protocol_header_.value().wholeHeaderLength())) { + return ReadOrParseState::Error; + } return ReadOrParseState::Done; } -absl::optional Filter::lenV2Address(char* buf) { +absl::optional Filter::lenV2Address(const char* buf) { const uint8_t proto_family = buf[PROXY_PROTO_V2_SIGNATURE_LEN + 1]; const int ver_cmd = buf[PROXY_PROTO_V2_SIGNATURE_LEN]; size_t len; @@ -165,15 +169,16 @@ absl::optional Filter::lenV2Address(char* buf) { return len; } -bool Filter::parseV2Header(char* buf) { +bool Filter::parseV2Header(const char* buf) { const int ver_cmd = buf[PROXY_PROTO_V2_SIGNATURE_LEN]; uint8_t upper_byte = buf[PROXY_PROTO_V2_HEADER_LEN - 2]; uint8_t lower_byte = buf[PROXY_PROTO_V2_HEADER_LEN - 1]; size_t hdr_addr_len = (upper_byte << 8) + lower_byte; if ((ver_cmd & 0xf) == PROXY_PROTO_V2_LOCAL) { - // This is locally-initiated, e.g. health-check, and should not override remote address - proxy_protocol_header_.emplace(WireHeader{hdr_addr_len}); + // This is locally-initiated, e.g. health-check, and should not override remote address. + // According to the spec, this address length should be zero for local connection. + proxy_protocol_header_.emplace(WireHeader{PROXY_PROTO_V2_HEADER_LEN, hdr_addr_len, 0, 0}); return true; } @@ -191,8 +196,8 @@ bool Filter::parseV2Header(char* buf) { uint16_t src_port; uint16_t dst_port; }); - pp_ipv4_addr* v4; - v4 = reinterpret_cast(&buf[PROXY_PROTO_V2_HEADER_LEN]); + const pp_ipv4_addr* v4; + v4 = reinterpret_cast(&buf[PROXY_PROTO_V2_HEADER_LEN]); sockaddr_in ra4, la4; memset(&ra4, 0, sizeof(ra4)); memset(&la4, 0, sizeof(la4)); @@ -204,7 +209,8 @@ bool Filter::parseV2Header(char* buf) { la4.sin_port = v4->dst_port; la4.sin_addr.s_addr = v4->dst_addr; proxy_protocol_header_.emplace( - WireHeader{hdr_addr_len - PROXY_PROTO_V2_ADDR_LEN_INET, Network::Address::IpVersion::v4, + WireHeader{PROXY_PROTO_V2_HEADER_LEN, hdr_addr_len, PROXY_PROTO_V2_ADDR_LEN_INET, + hdr_addr_len - PROXY_PROTO_V2_ADDR_LEN_INET, Network::Address::IpVersion::v4, std::make_shared(&ra4), std::make_shared(&la4)}); return true; @@ -215,8 +221,8 @@ bool Filter::parseV2Header(char* buf) { uint16_t src_port; uint16_t dst_port; }); - pp_ipv6_addr* v6; - v6 = reinterpret_cast(&buf[PROXY_PROTO_V2_HEADER_LEN]); + const pp_ipv6_addr* v6; + v6 = reinterpret_cast(&buf[PROXY_PROTO_V2_HEADER_LEN]); sockaddr_in6 ra6, la6; memset(&ra6, 0, sizeof(ra6)); memset(&la6, 0, sizeof(la6)); @@ -229,6 +235,7 @@ bool Filter::parseV2Header(char* buf) { safeMemcpy(&(la6.sin6_addr.s6_addr), &(v6->dst_addr)); proxy_protocol_header_.emplace(WireHeader{ + PROXY_PROTO_V2_HEADER_LEN, hdr_addr_len, PROXY_PROTO_V2_ADDR_LEN_INET6, hdr_addr_len - PROXY_PROTO_V2_ADDR_LEN_INET6, Network::Address::IpVersion::v6, std::make_shared(ra6), std::make_shared(la6)}); @@ -240,7 +247,7 @@ bool Filter::parseV2Header(char* buf) { return false; } -bool Filter::parseV1Header(char* buf, size_t len) { +bool Filter::parseV1Header(const char* buf, size_t len) { std::string proxy_line; proxy_line.assign(buf, len); const auto trimmed_proxy_line = StringUtil::rtrim(proxy_line); @@ -277,7 +284,8 @@ bool Filter::parseV1Header(char* buf, size_t len) { return false; } proxy_protocol_header_.emplace( - WireHeader{0, Network::Address::IpVersion::v4, remote_address, local_address}); + WireHeader{len, 0, 0, 0, Network::Address::IpVersion::v4, remote_address, local_address}); + return true; } else if (line_parts[1] == "TCP6") { const Network::Address::InstanceConstSharedPtr remote_address = Network::Utility::parseInternetAddressAndPortNoThrow("[" + std::string{line_parts[2]} + @@ -290,41 +298,17 @@ bool Filter::parseV1Header(char* buf, size_t len) { return false; } proxy_protocol_header_.emplace( - WireHeader{0, Network::Address::IpVersion::v6, remote_address, local_address}); + WireHeader{len, 0, 0, 0, Network::Address::IpVersion::v6, remote_address, local_address}); + return true; } else { ENVOY_LOG(debug, "failed to read proxy protocol"); return false; } } + proxy_protocol_header_.emplace(WireHeader{len, 0, 0, 0}); return true; } -ReadOrParseState Filter::parseExtensions(Network::IoHandle& io_handle, uint8_t* buf, - size_t buf_size, size_t* buf_off) { - // If we ever implement extensions elsewhere, be sure to - // continue to skip and ignore those for LOCAL. - while (proxy_protocol_header_.value().extensions_length_) { - int to_read = std::min(buf_size, proxy_protocol_header_.value().extensions_length_); - buf += (nullptr != buf_off) ? *buf_off : 0; - const auto recv_result = io_handle.recv(buf, to_read, 0); - if (!recv_result.ok()) { - if (recv_result.err_->getErrorCode() == Api::IoError::IoErrorCode::Again) { - return ReadOrParseState::TryAgainLater; - } - ENVOY_LOG(debug, "failed to read proxy protocol (no bytes avail)"); - return ReadOrParseState::Error; - } - - proxy_protocol_header_.value().extensions_length_ -= recv_result.return_value_; - - if (nullptr != buf_off) { - *buf_off += recv_result.return_value_; - } - } - - return ReadOrParseState::Done; -} - /** * @note A TLV is arranged in the following format: * struct pp2_tlv { @@ -335,33 +319,33 @@ ReadOrParseState Filter::parseExtensions(Network::IoHandle& io_handle, uint8_t* * }; * See https://www.haproxy.org/download/2.1/doc/proxy-protocol.txt for details */ -bool Filter::parseTlvs(const std::vector& tlvs) { +bool Filter::parseTlvs(const uint8_t* buf, size_t len) { size_t idx{0}; - while (idx < tlvs.size()) { - const uint8_t tlv_type = tlvs[idx]; + while (idx < len) { + const uint8_t tlv_type = buf[idx]; idx++; - if ((idx + 1) >= tlvs.size()) { + if ((idx + 1) >= len) { ENVOY_LOG(debug, fmt::format("failed to read proxy protocol extension. No bytes for TLV length. " "Extension length is {}, current index is {}, current type is {}.", - tlvs.size(), idx, tlv_type)); + len, idx, tlv_type)); return false; } - const uint8_t tlv_length_upper = tlvs[idx]; - const uint8_t tlv_length_lower = tlvs[idx + 1]; + const uint8_t tlv_length_upper = buf[idx]; + const uint8_t tlv_length_lower = buf[idx + 1]; const size_t tlv_value_length = (tlv_length_upper << 8) + tlv_length_lower; idx += 2; // Get the value. - if ((idx + tlv_value_length - 1) >= tlvs.size()) { + if ((idx + tlv_value_length - 1) >= len) { ENVOY_LOG( debug, fmt::format("failed to read proxy protocol extension. No bytes for TLV value. " "Extension length is {}, current index is {}, current type is {}, current " "value length is {}.", - tlvs.size(), idx, tlv_type, tlv_length_upper)); + len, idx, tlv_type, tlv_length_upper)); return false; } @@ -369,8 +353,7 @@ bool Filter::parseTlvs(const std::vector& tlvs) { auto key_value_pair = config_->isTlvTypeNeeded(tlv_type); if (nullptr != key_value_pair) { ProtobufWkt::Value metadata_value; - metadata_value.set_string_value(reinterpret_cast(tlvs.data() + idx), - tlv_value_length); + metadata_value.set_string_value(reinterpret_cast(buf + idx), tlv_value_length); std::string metadata_key = key_value_pair->metadata_namespace().empty() ? "envoy.filters.listener.proxy_protocol" @@ -385,164 +368,108 @@ bool Filter::parseTlvs(const std::vector& tlvs) { } idx += tlv_value_length; - ASSERT(idx <= tlvs.size()); + ASSERT(idx <= len); } return true; } -ReadOrParseState Filter::readExtensions(Network::IoHandle& io_handle) { - // Parse and discard the extensions if this is a local command or there's no TLV needs to be saved - // to metadata. - if (proxy_protocol_header_.value().local_command_ || 0 == config_->numberOfNeededTlvTypes()) { - // buf_ is no longer in use so we re-use it to read/discard. - return parseExtensions(io_handle, reinterpret_cast(buf_), sizeof(buf_), nullptr); - } - - // Initialize the buf_tlv_ only when we need to read the TLVs. - if (buf_tlv_.empty()) { - buf_tlv_.resize(proxy_protocol_header_.value().extensions_length_); +ReadOrParseState Filter::readExtensions(Network::ListenerFilterBuffer& buffer) { + auto raw_slice = buffer.rawSlice(); + // waiting for more data if there is no enough data for extensions. + if (raw_slice.len_ < (proxy_protocol_header_.value().wholeHeaderLength())) { + return ReadOrParseState::TryAgainLater; } - // Parse until we have all the TLVs in buf_tlv. - const ReadOrParseState parse_extensions_state = - parseExtensions(io_handle, buf_tlv_.data(), buf_tlv_.size(), &buf_tlv_off_); - if (parse_extensions_state != ReadOrParseState::Done) { - return parse_extensions_state; + if (proxy_protocol_header_.value().local_command_ || 0 == config_->numberOfNeededTlvTypes()) { + // Ignores the extensions if this is a local command or there's no TLV needs to be saved + // to metadata. Those will drained from the buffer in the end. + return ReadOrParseState::Done; } - if (!parseTlvs(buf_tlv_)) { + const uint8_t* buf = static_cast(raw_slice.mem_) + + proxy_protocol_header_.value().headerLengthWithoutExtension(); + if (!parseTlvs(buf, proxy_protocol_header_.value().extensions_length_)) { return ReadOrParseState::Error; } return ReadOrParseState::Done; } -ReadOrParseState Filter::readProxyHeader(Network::IoHandle& io_handle) { - while (buf_off_ < MAX_PROXY_PROTO_LEN_V2) { - const auto result = - io_handle.recv(buf_ + buf_off_, MAX_PROXY_PROTO_LEN_V2 - buf_off_, MSG_PEEK); - - if (!result.ok()) { - if (result.err_->getErrorCode() == Api::IoError::IoErrorCode::Again) { - return ReadOrParseState::TryAgainLater; - } - ENVOY_LOG(debug, "failed to read proxy protocol (no bytes read)"); +ReadOrParseState Filter::readProxyHeader(Network::ListenerFilterBuffer& buffer) { + auto raw_slice = buffer.rawSlice(); + const char* buf = static_cast(raw_slice.mem_); + + if (raw_slice.len_ >= PROXY_PROTO_V2_HEADER_LEN) { + const char* sig = PROXY_PROTO_V2_SIGNATURE; + if (!memcmp(buf, sig, PROXY_PROTO_V2_SIGNATURE_LEN)) { + header_version_ = V2; + } else if (memcmp(buf, PROXY_PROTO_V1_SIGNATURE, PROXY_PROTO_V1_SIGNATURE_LEN)) { + // It is not v2, and can't be v1, so no sense hanging around: it is invalid + ENVOY_LOG(debug, "failed to read proxy protocol (exceed max v1 header len)"); return ReadOrParseState::Error; } - ssize_t nread = result.return_value_; + } - if (nread < 1) { - ENVOY_LOG(debug, "failed to read proxy protocol (no bytes read)"); + if (header_version_ == V2) { + const int ver_cmd = buf[PROXY_PROTO_V2_SIGNATURE_LEN]; + if (((ver_cmd & 0xf0) >> 4) != PROXY_PROTO_V2_VERSION) { + ENVOY_LOG(debug, "Unsupported V2 proxy protocol version"); return ReadOrParseState::Error; } - - if (buf_off_ + nread >= PROXY_PROTO_V2_HEADER_LEN) { - const char* sig = PROXY_PROTO_V2_SIGNATURE; - if (!memcmp(buf_, sig, PROXY_PROTO_V2_SIGNATURE_LEN)) { - header_version_ = V2; - } else if (memcmp(buf_, PROXY_PROTO_V1_SIGNATURE, PROXY_PROTO_V1_SIGNATURE_LEN)) { - // It is not v2, and can't be v1, so no sense hanging around: it is invalid - ENVOY_LOG(debug, "failed to read proxy protocol (exceed max v1 header len)"); - return ReadOrParseState::Error; - } + absl::optional addr_len_opt = lenV2Address(buf); + if (!addr_len_opt.has_value()) { + return ReadOrParseState::Error; } - - if (header_version_ == V2) { - const int ver_cmd = buf_[PROXY_PROTO_V2_SIGNATURE_LEN]; - if (((ver_cmd & 0xf0) >> 4) != PROXY_PROTO_V2_VERSION) { - ENVOY_LOG(debug, "Unsupported V2 proxy protocol version"); - return ReadOrParseState::Error; - } - if (buf_off_ < PROXY_PROTO_V2_HEADER_LEN) { - ssize_t exp = PROXY_PROTO_V2_HEADER_LEN - buf_off_; - const auto read_result = io_handle.recv(buf_ + buf_off_, exp, 0); - if (!read_result.ok() || read_result.return_value_ != uint64_t(exp)) { - ENVOY_LOG(debug, "failed to read proxy protocol (remote closed)"); - return ReadOrParseState::Error; - } - buf_off_ += read_result.return_value_; - nread -= read_result.return_value_; - } - absl::optional addr_len_opt = lenV2Address(buf_); - if (!addr_len_opt.has_value()) { - return ReadOrParseState::Error; - } - ssize_t addr_len = addr_len_opt.value(); - uint8_t upper_byte = buf_[PROXY_PROTO_V2_HEADER_LEN - 2]; - uint8_t lower_byte = buf_[PROXY_PROTO_V2_HEADER_LEN - 1]; - ssize_t hdr_addr_len = (upper_byte << 8) + lower_byte; - if (hdr_addr_len < addr_len) { - ENVOY_LOG(debug, "failed to read proxy protocol (insufficient data)"); + ssize_t addr_len = addr_len_opt.value(); + uint8_t upper_byte = buf[PROXY_PROTO_V2_HEADER_LEN - 2]; + uint8_t lower_byte = buf[PROXY_PROTO_V2_HEADER_LEN - 1]; + ssize_t hdr_addr_len = (upper_byte << 8) + lower_byte; + if (hdr_addr_len < addr_len) { + ENVOY_LOG(debug, + "incorrect address length, address length = {}, the expected address length = {}", + hdr_addr_len, addr_len); + return ReadOrParseState::Error; + } + // waiting for more data if there is no enough data for address. + if (raw_slice.len_ >= static_cast(PROXY_PROTO_V2_HEADER_LEN + addr_len)) { + // The TLV remain, they are parsed in `parseTlvs()` which is called from the + // parent (if needed). + if (parseV2Header(buf)) { + return ReadOrParseState::Done; + } else { return ReadOrParseState::Error; } - if (ssize_t(buf_off_) + nread >= PROXY_PROTO_V2_HEADER_LEN + addr_len) { - ssize_t missing = (PROXY_PROTO_V2_HEADER_LEN + addr_len) - buf_off_; - const auto read_result = io_handle.recv(buf_ + buf_off_, missing, 0); - if (!read_result.ok() || read_result.return_value_ != uint64_t(missing)) { - ENVOY_LOG(debug, "failed to read proxy protocol (remote closed)"); - return ReadOrParseState::Error; - } - buf_off_ += read_result.return_value_; - // The TLV remain, they are read/discard in parseExtensions() which is called from the - // parent (if needed). - if (parseV2Header(buf_)) { - return ReadOrParseState::Done; - } else { - return ReadOrParseState::Error; - } - } else if (nread != 0) { - const auto result = io_handle.recv(buf_ + buf_off_, nread, 0); - nread = result.return_value_; - if (!result.ok()) { - ENVOY_LOG(debug, "failed to read proxy protocol (remote closed)"); - return ReadOrParseState::Error; - } - buf_off_ += nread; - } - } else { - // continue searching buf_ from where we left off - for (; search_index_ < buf_off_ + nread; search_index_++) { - if (buf_[search_index_] == '\n' && buf_[search_index_ - 1] == '\r') { - if (search_index_ == 1) { - // This could be the binary protocol. It cannot be the ascii protocol - header_version_ = InProgress; - } else { - header_version_ = V1; - search_index_++; - } + } + } else { + // continue searching buffer from where we left off + for (; search_index_ < raw_slice.len_; search_index_++) { + if (buf[search_index_] == '\n' && buf[search_index_ - 1] == '\r') { + if (search_index_ == 1) { + // There is not enough data to determine if it contains the v2 protocol signature, so wait + // for more data. break; + } else { + header_version_ = V1; + search_index_++; } + break; } + } - // If we bailed on the first char, we might be v2, but are for sure not v1. Thus we - // can read up to min(PROXY_PROTO_V2_HEADER_LEN, bytes_avail). If we bailed after first - // char, but before we hit \r\n, read up to search_index_. We're asking only for - // bytes we've already seen so there should be no block or fail - size_t ntoread; - if (header_version_ == InProgress) { - ntoread = nread; - } else { - ntoread = search_index_ - buf_off_; - } - - const auto result = io_handle.recv(buf_ + buf_off_, ntoread, 0); - nread = result.return_value_; - ASSERT(result.ok() && size_t(nread) == ntoread); - - buf_off_ += nread; + if (search_index_ > MAX_PROXY_PROTO_LEN_V1) { + return ReadOrParseState::Error; + } - if (header_version_ == V1) { - if (parseV1Header(buf_, buf_off_)) { - return ReadOrParseState::Done; - } else { - return ReadOrParseState::Error; - } + if (header_version_ == V1) { + if (parseV1Header(buf, search_index_)) { + return ReadOrParseState::Done; + } else { + return ReadOrParseState::Error; } } } - ENVOY_LOG(debug, "failed to read proxy protocol (exceed max v2 header len)"); - return ReadOrParseState::Error; + return ReadOrParseState::TryAgainLater; } } // namespace ProxyProtocol diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h index 469c838f9b83f..86d292615cc3f 100644 --- a/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol.h @@ -66,7 +66,7 @@ class Config : public Logger::Loggable { using ConfigSharedPtr = std::shared_ptr; -enum ProxyProtocolVersion { Unknown = -1, InProgress = -2, V1 = 1, V2 = 2 }; +enum ProxyProtocolVersion { Unknown = 0, V1 = 1, V2 = 2 }; enum class ReadOrParseState { Done, TryAgainLater, Error }; @@ -86,6 +86,8 @@ class Filter : public Network::ListenerFilter, Logger::Loggable& tlvs); - ReadOrParseState readExtensions(Network::IoHandle& io_handle); + bool parseTlvs(const uint8_t* buf, size_t len); + ReadOrParseState readExtensions(Network::ListenerFilterBuffer& buffer); /** * Given a char * & len, parse the header as per spec. * @return bool true if parsing succeeded, false if parsing failed. */ - bool parseV1Header(char* buf, size_t len); - bool parseV2Header(char* buf); - absl::optional lenV2Address(char* buf); + bool parseV1Header(const char* buf, size_t len); + bool parseV2Header(const char* buf); + absl::optional lenV2Address(const char* buf); Network::ListenerFilterCallbacks* cb_{}; - // The offset in buf_ that has been fully read - size_t buf_off_{}; - // The index in buf_ where the search for '\r\n' should continue from size_t search_index_{1}; ProxyProtocolVersion header_version_{Unknown}; - // Stores the portion of the first line that has been read so far. - char buf_[MAX_PROXY_PROTO_LEN_V2]; - - /** - * Store the extension TLVs if they need to be read. - */ - std::vector buf_tlv_; - - /** - * The index in buf_tlv_ that has been fully read. - */ - size_t buf_tlv_off_{}; - ConfigSharedPtr config_; absl::optional proxy_protocol_header_; + size_t max_proxy_protocol_len_{MAX_PROXY_PROTO_LEN_V2}; }; } // namespace ProxyProtocol diff --git a/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h b/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h index f72258ce0ff55..446a1bbb6d73a 100644 --- a/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h +++ b/source/extensions/filters/listener/proxy_protocol/proxy_protocol_header.h @@ -10,17 +10,36 @@ namespace ListenerFilters { namespace ProxyProtocol { struct WireHeader { - WireHeader(size_t extensions_length) - : extensions_length_(extensions_length), protocol_version_(Network::Address::IpVersion::v4), - remote_address_(nullptr), local_address_(nullptr), local_command_(true) {} - WireHeader(size_t extensions_length, Network::Address::IpVersion protocol_version, + WireHeader(size_t header_length, size_t header_addr_length, size_t addr_lengh, + size_t extensions_length) + : header_length_(header_length), header_addr_length_(header_addr_length), + addr_lengh_(addr_lengh), extensions_length_(extensions_length), + protocol_version_(Network::Address::IpVersion::v4), remote_address_(nullptr), + local_address_(nullptr), local_command_(true) {} + WireHeader(size_t header_length, size_t header_addr_length, size_t addr_lengh, + size_t extensions_length, Network::Address::IpVersion protocol_version, Network::Address::InstanceConstSharedPtr remote_address, Network::Address::InstanceConstSharedPtr local_address) - : extensions_length_(extensions_length), protocol_version_(protocol_version), - remote_address_(remote_address), local_address_(local_address), local_command_(false) { + : header_length_(header_length), header_addr_length_(header_addr_length), + addr_lengh_(addr_lengh), extensions_length_(extensions_length), + protocol_version_(protocol_version), remote_address_(remote_address), + local_address_(local_address), local_command_(false) { ASSERT(extensions_length_ <= 65535); } + + size_t wholeHeaderLength() { return header_length_ + header_addr_length_; } + + size_t headerLengthWithoutExtension() { return header_length_ + addr_lengh_; } + + // For v1, this is whole length of the header util the end `\r\n`. + // For v2, this is PROXY_PROTO_V2_HEADER_LEN, without address and extensions; + size_t header_length_; + // For v1, this is zero. For v2, this is the length of address and extensions; + size_t header_addr_length_; + // For v1, this is zero. For v2, this is the length of address. + size_t addr_lengh_; + // For v1, this is zero, For v2, this is the length of extensions. size_t extensions_length_; const Network::Address::IpVersion protocol_version_; const Network::Address::InstanceConstSharedPtr remote_address_; diff --git a/source/extensions/filters/listener/tls_inspector/BUILD b/source/extensions/filters/listener/tls_inspector/BUILD index ec2a3e0135644..341dfef29f5f5 100644 --- a/source/extensions/filters/listener/tls_inspector/BUILD +++ b/source/extensions/filters/listener/tls_inspector/BUILD @@ -27,6 +27,7 @@ envoy_cc_library( "//envoy/network:filter_interface", "//envoy/network:listen_socket_interface", "//source/common/api:os_sys_calls_lib", + "//source/common/buffer:buffer_lib", "//source/common/common:assert_lib", "//source/common/common:hex_lib", "//source/common/common:minimal_logger_lib", diff --git a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc index 4ce78ceffaae5..f42171d222129 100644 --- a/source/extensions/filters/listener/tls_inspector/tls_inspector.cc +++ b/source/extensions/filters/listener/tls_inspector/tls_inspector.cc @@ -11,6 +11,7 @@ #include "envoy/stats/scope.h" #include "source/common/api/os_sys_calls_impl.h" +#include "source/common/buffer/buffer_impl.h" #include "source/common/common/assert.h" #include "source/common/common/hex.h" #include "source/common/protobuf/utility.h" @@ -75,54 +76,15 @@ Config::Config( bssl::UniquePtr Config::newSsl() { return bssl::UniquePtr{SSL_new(ssl_ctx_.get())}; } -thread_local uint8_t Filter::buf_[Config::TLS_MAX_CLIENT_HELLO]; - Filter::Filter(const ConfigSharedPtr config) : config_(config), ssl_(config_->newSsl()) { - RELEASE_ASSERT(sizeof(buf_) >= config_->maxClientHelloSize(), ""); - SSL_set_app_data(ssl_.get(), this); SSL_set_accept_state(ssl_.get()); } Network::FilterStatus Filter::onAccept(Network::ListenerFilterCallbacks& cb) { ENVOY_LOG(debug, "tls inspector: new connection accepted"); - Network::ConnectionSocket& socket = cb.socket(); cb_ = &cb; - ParseState parse_state = onRead(); - switch (parse_state) { - case ParseState::Error: - // As per discussion in https://github.com/envoyproxy/envoy/issues/7864 - // we don't add new enum in FilterStatus so we have to signal the caller - // the new condition. - cb.socket().close(); - return Network::FilterStatus::StopIteration; - case ParseState::Done: - return Network::FilterStatus::Continue; - case ParseState::Continue: - // do nothing but create the event - socket.ioHandle().initializeFileEvent( - cb.dispatcher(), - [this](uint32_t events) { - ASSERT(events == Event::FileReadyType::Read); - ParseState parse_state = onRead(); - switch (parse_state) { - case ParseState::Error: - done(false); - break; - case ParseState::Done: - done(true); - break; - case ParseState::Continue: - // do nothing but wait for the next event - break; - } - }, - Event::PlatformDefaultTriggerType, Event::FileReadyType::Read); - return Network::FilterStatus::StopIteration; - } - - IS_ENVOY_BUG("unexpected tcp filter parse_state"); return Network::FilterStatus::StopIteration; } @@ -158,50 +120,31 @@ void Filter::onServername(absl::string_view name) { clienthello_success_ = true; } -ParseState Filter::onRead() { - // This receive code is somewhat complicated, because it must be done as a MSG_PEEK because - // there is no way for a listener-filter to pass payload data to the ConnectionImpl and filters - // that get created later. - // - // We request from the file descriptor to get events every time new data is available, - // even if previous data has not been read, which is always the case due to MSG_PEEK. When - // the TlsInspector completes and passes the socket along, a new FileEvent is created for the - // socket, so that new event is immediately signaled as readable because it is new and the socket - // is readable, even though no new events have occurred. - // - // TODO(ggreenway): write an integration test to ensure the events work as expected on all - // platforms. - const auto result = cb_->socket().ioHandle().recv(buf_, config_->maxClientHelloSize(), MSG_PEEK); - ENVOY_LOG(trace, "tls inspector: recv: {}", result.return_value_); - - if (!result.ok()) { - if (result.err_->getErrorCode() == Api::IoError::IoErrorCode::Again) { - return ParseState::Continue; - } - config_->stats().read_error_.inc(); - return ParseState::Error; - } - - if (result.return_value_ == 0) { - config_->stats().connection_closed_.inc(); - return ParseState::Error; - } +Network::FilterStatus Filter::onData(Network::ListenerFilterBuffer& buffer) { + auto raw_slice = buffer.rawSlice(); + ENVOY_LOG(trace, "tls inspector: recv: {}", raw_slice.len_); // Because we're doing a MSG_PEEK, data we've seen before gets returned every time, so // skip over what we've already processed. - if (static_cast(result.return_value_) > read_) { - const uint8_t* data = buf_ + read_; - const size_t len = result.return_value_ - read_; - read_ = result.return_value_; - return parseClientHello(data, len); + if (static_cast(raw_slice.len_) > read_) { + const uint8_t* data = static_cast(raw_slice.mem_) + read_; + const size_t len = raw_slice.len_ - read_; + read_ = raw_slice.len_; + ParseState parse_state = parseClientHello(data, len); + switch (parse_state) { + case ParseState::Error: + cb_->socket().ioHandle().close(); + return Network::FilterStatus::StopIteration; + case ParseState::Done: + // Finish the inspect. + return Network::FilterStatus::Continue; + case ParseState::Continue: + // Do nothing but wait for the next event. + return Network::FilterStatus::StopIteration; + } + IS_ENVOY_BUG("unexpected tcp filter parse_state"); } - return ParseState::Continue; -} - -void Filter::done(bool success) { - ENVOY_LOG(trace, "tls inspector: done: {}", success); - cb_->socket().ioHandle().resetFileEvents(); - cb_->continueFilterChain(success); + return Network::FilterStatus::StopIteration; } ParseState Filter::parseClientHello(const void* data, size_t len) { diff --git a/source/extensions/filters/listener/tls_inspector/tls_inspector.h b/source/extensions/filters/listener/tls_inspector/tls_inspector.h index 0e32ef2517c95..cea9f63204929 100644 --- a/source/extensions/filters/listener/tls_inspector/tls_inspector.h +++ b/source/extensions/filters/listener/tls_inspector/tls_inspector.h @@ -20,9 +20,7 @@ namespace TlsInspector { * All stats for the TLS inspector. @see stats_macros.h */ #define ALL_TLS_INSPECTOR_STATS(COUNTER) \ - COUNTER(connection_closed) \ COUNTER(client_hello_too_large) \ - COUNTER(read_error) \ COUNTER(tls_found) \ COUNTER(tls_not_found) \ COUNTER(alpn_found) \ @@ -81,11 +79,12 @@ class Filter : public Network::ListenerFilter, Logger::LoggablemaxClientHelloSize(); } private: ParseState parseClientHello(const void* data, size_t len); ParseState onRead(); - void done(bool success); void onALPN(const unsigned char* data, unsigned int len); void onServername(absl::string_view name); void createJA3Hash(const SSL_CLIENT_HELLO* ssl_client_hello); @@ -98,8 +97,6 @@ class Filter : public Network::ListenerFilter, Logger::Loggable( + socket_->ioHandle(), listener_.dispatcher(), + [this](bool error) { + socket_->ioHandle().close(); + if (error) { + listener_.stats_.downstream_listener_filter_error_.inc(); + } else { + listener_.stats_.downstream_listener_filter_remote_close_.inc(); + } + continueFilterChain(false); + }, + [this](Network::ListenerFilterBufferImpl& filter_buffer) { + ASSERT((*iter_)->maxReadBytes() != 0); + Network::FilterStatus status = (*iter_)->onData(filter_buffer); + if (status == Network::FilterStatus::StopIteration) { + if (socket_->ioHandle().isOpen()) { + // The listener filter should not wait for more data when it has already received + // all the data it requested. + ASSERT(filter_buffer.rawSlice().len_ < (*iter_)->maxReadBytes()); + // Check if the maxReadBytes is changed or not. If change, + // reset the buffer capacity. + if ((*iter_)->maxReadBytes() > filter_buffer.capacity()) { + filter_buffer.resetCapacity((*iter_)->maxReadBytes()); + // Activate `Read` event manually in case the data already + // available in the socket buffer. + filter_buffer.activateFileEvent(Event::FileReadyType::Read); + } + } + return; + } + continueFilterChain(true); + }, + (*iter_)->maxReadBytes()); +} + void ActiveTcpSocket::continueFilterChain(bool success) { if (success) { bool no_error = true; @@ -85,11 +121,33 @@ void ActiveTcpSocket::continueFilterChain(bool success) { // The filter is responsible for calling us again at a later time to continue the filter // chain from the next filter. if (!socket().ioHandle().isOpen()) { - // break the loop but should not create new connection + // Break the loop but should not create new connection. no_error = false; break; } else { - // Blocking at the filter but no error + // If the listener maxReadBytes() is 0, then it shouldn't return + // `FilterStatus::StopIteration` from `onAccept` to wait for more data. + ASSERT((*iter_)->maxReadBytes() != 0); + if (listener_filter_buffer_ == nullptr) { + if ((*iter_)->maxReadBytes() > 0) { + createListenerFilterBuffer(); + } + } else { + // If the current filter expect more data than previous filters, then + // increase the filter buffer's capacity. + if (listener_filter_buffer_->capacity() < (*iter_)->maxReadBytes()) { + listener_filter_buffer_->resetCapacity((*iter_)->maxReadBytes()); + } + } + if (listener_filter_buffer_ != nullptr) { + // There are two cases for activate event manually: One is + // the data is already available when connect, activate the read event to peek + // data from the socket . Another one is the data already + // peeked into the buffer when previous filter processing the data, then activate the + // read event to trigger the current filter callback to process the data. + listener_filter_buffer_->activateFileEvent(Event::FileReadyType::Read); + } + // Waiting for more data. return; } } @@ -142,7 +200,9 @@ void ActiveTcpSocket::newConnection() { } // Reset the file events which are registered by listener filter. // reference https://github.com/envoyproxy/envoy/issues/8925. - socket_->ioHandle().resetFileEvents(); + if (listener_filter_buffer_ != nullptr) { + listener_filter_buffer_->reset(); + } accept_filters_.clear(); // Create a new connection on this listener. listener_.newConnection(std::move(socket_), std::move(stream_info_)); diff --git a/source/server/active_tcp_socket.h b/source/server/active_tcp_socket.h index 9900bbd66b883..8e46aa7eba573 100644 --- a/source/server/active_tcp_socket.h +++ b/source/server/active_tcp_socket.h @@ -13,6 +13,7 @@ #include "envoy/network/listener.h" #include "source/common/common/linked_object.h" +#include "source/common/network/listener_filter_buffer_impl.h" #include "source/server/active_listener_base.h" namespace Envoy { @@ -48,6 +49,13 @@ struct ActiveTcpSocket : public Network::ListenerFilterManager, } return listener_filter_->onAccept(cb); } + + Network::FilterStatus onData(Network::ListenerFilterBuffer& buffer) override { + return listener_filter_->onData(buffer); + } + + size_t maxReadBytes() const override { return listener_filter_->maxReadBytes(); } + /** * Check if this filter filter should be disabled on the incoming socket. * @param cb the callbacks the filter instance can use to communicate with the filter chain. @@ -87,6 +95,8 @@ struct ActiveTcpSocket : public Network::ListenerFilterManager, StreamInfo::FilterState& filterState() override { return *stream_info_->filterState().get(); } + void createListenerFilterBuffer(); + // The owner of this ActiveTcpSocket. ActiveStreamListenerBase& listener_; Network::ConnectionSocketPtr socket_; @@ -96,6 +106,8 @@ struct ActiveTcpSocket : public Network::ListenerFilterManager, Event::TimerPtr timer_; std::unique_ptr stream_info_; bool connected_{false}; + + Network::ListenerFilterBufferImplPtr listener_filter_buffer_; }; } // namespace Server diff --git a/test/common/network/BUILD b/test/common/network/BUILD index 39ca99aeeebd8..36163df959d97 100644 --- a/test/common/network/BUILD +++ b/test/common/network/BUILD @@ -6,6 +6,7 @@ load( "envoy_cc_test", "envoy_cc_test_library", "envoy_package", + "envoy_proto_library", ) licenses(["notice"]) # Apache 2 @@ -163,6 +164,16 @@ envoy_cc_test( ], ) +envoy_cc_test( + name = "listener_filter_buffer_impl_test", + srcs = ["listener_filter_buffer_impl_test.cc"], + deps = [ + "//source/common/network:listener_filter_buffer_lib", + "//test/mocks/event:event_mocks", + "//test/mocks/network:io_handle_mocks", + ], +) + envoy_cc_test( name = "listener_impl_test", srcs = ["listener_impl_test.cc"], @@ -461,3 +472,21 @@ envoy_cc_test( "//test/mocks/network:network_mocks", ], ) + +envoy_proto_library( + name = "listener_filter_buffer_fuzz_proto", + srcs = ["listener_filter_buffer_fuzz.proto"], +) + +envoy_cc_fuzz_test( + name = "listener_filter_buffer_fuzz_test", + srcs = ["listener_filter_buffer_fuzz_test.cc"], + corpus = "listener_filter_buffer_corpus", + deps = [ + ":listener_filter_buffer_fuzz_proto_cc_proto", + "//source/common/network:listener_filter_buffer_lib", + "//test/fuzz:utility_lib", + "//test/mocks/event:event_mocks", + "//test/mocks/network:io_handle_mocks", + ], +) diff --git a/test/common/network/listener_filter_buffer_corpus/basic b/test/common/network/listener_filter_buffer_corpus/basic new file mode 100644 index 0000000000000..e122c7e244a22 --- /dev/null +++ b/test/common/network/listener_filter_buffer_corpus/basic @@ -0,0 +1,25 @@ +max_bytes_read: 10 +actions { + peek_from_socket: 3 +} +actions { + drain: 1 +} +actions { + peek_from_socket: 5 +} +actions { + reset_capacity: 20 +} +actions { + peek_from_socket: 15 +} +actions { + reset_capacity: 10 +} +actions { + peek_from_socket: 5 +} +actions { + drain: 5 +} \ No newline at end of file diff --git a/test/common/network/listener_filter_buffer_fuzz.proto b/test/common/network/listener_filter_buffer_fuzz.proto new file mode 100644 index 0000000000000..f8ed1a7c68ead --- /dev/null +++ b/test/common/network/listener_filter_buffer_fuzz.proto @@ -0,0 +1,16 @@ +syntax = "proto3"; + +package test.common.network; + +message Action { + oneof action_selector { + uint64 readable = 1; + uint64 drain = 2; + uint64 resetCapacity = 3; + } +} + +message ListenerFilterBufferFuzzTestCase { + uint64 max_bytes_read = 1; + repeated Action actions = 2; +} diff --git a/test/common/network/listener_filter_buffer_fuzz_test.cc b/test/common/network/listener_filter_buffer_fuzz_test.cc new file mode 100644 index 0000000000000..7d0b9fb3cb87a --- /dev/null +++ b/test/common/network/listener_filter_buffer_fuzz_test.cc @@ -0,0 +1,136 @@ +#include "source/common/common/assert.h" +#include "source/common/common/logger.h" +#include "source/common/network/listener_filter_buffer_impl.h" + +#include "test/common/network/listener_filter_buffer_fuzz.pb.h" +#include "test/fuzz/fuzz_runner.h" +#include "test/fuzz/utility.h" +#include "test/mocks/event/mocks.h" +#include "test/mocks/network/io_handle.h" + +#include "gtest/gtest.h" + +using testing::_; +using testing::ByMove; +using testing::Return; +using testing::SaveArg; + +namespace Envoy { +namespace Network { +namespace { + +// The max size of the listener filter buffer. +constexpr uint32_t max_buffer_size = 16 * 1024; +// The max size of available data on the socket. It can be large than +// buffer size, but we won't peek those extra data. +constexpr uint32_t max_readable_size = max_buffer_size + 1024; + +class ListenerFilterBufferFuzzer { +public: + void fuzz(const test::common::network::ListenerFilterBufferFuzzTestCase& input) { + // Ensure the buffer is not exceed the limit we set. + auto max_bytes_read = input.max_bytes_read() % max_buffer_size; + // There won't be any case the max size of buffer is 0. + if (max_bytes_read == 0) { + return; + } + + EXPECT_CALL(io_handle_, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, + Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback_)); + + // Use the on_data callback to verify the data. + auto on_data_cb = [&](ListenerFilterBuffer& buffer) { + auto raw_slice = buffer.rawSlice(); + std::string data(reinterpret_cast(raw_slice.mem_), raw_slice.len_); + // The available data may be more than the buffer size, also, the buffer size + // can be reduced by drain. + FUZZ_ASSERT(data == available_data_.substr(0, max_bytes_read - drained_size_)); + }; + auto listener_buffer = std::make_unique( + io_handle_, dispatcher_, [&](bool) {}, on_data_cb, max_bytes_read); + + for (auto i = 0; i < input.actions().size(); i++) { + const char insert_value = 'a' + i % 26; + + switch (input.actions(i).action_selector_case()) { + case test::common::network::Action::kReadable: { + // Generate the available data, and ensure it is under the max_readable_size. + auto append_data_size = + input.actions(i).readable() % (max_readable_size - available_data_.size()); + // If the available is 0, then emulate an `EAGAIN`. + if (append_data_size == 0) { + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove(Api::IoCallUint64Result( + 0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(), + IoSocketError::deleteIoError))))); + } else { + available_data_.insert(available_data_.end(), append_data_size, insert_value); + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + auto copy_size = std::min(length, available_data_.size()); + ::memcpy(buffer, available_data_.data(), copy_size); + return Api::IoCallUint64Result(copy_size, + Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + drained_size_ = 0; + } + // Trigger the peek by event. + file_event_callback_(Event::FileReadyType::Read); + break; + } + case test::common::network::Action::kDrain: { + // The drain method only support drain size less than the buffer size. + auto drain_size = std::min(input.actions(i).drain(), listener_buffer->rawSlice().len_); + if (drain_size != 0) { + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(0, flags); + EXPECT_EQ(drain_size, length); + ::memcpy(buffer, available_data_.data(), drain_size); + available_data_ = available_data_.substr(drain_size); + return Api::IoCallUint64Result(drain_size, + Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + } + drained_size_ += drain_size; + listener_buffer->drain(drain_size); + // Reuse the on_data callback to validate the buffer data. + on_data_cb(*listener_buffer); + break; + } + case test::common::network::Action::kResetCapacity: { + auto capacity_size = input.actions(i).drain() % max_buffer_size; + if (capacity_size == 0) { + break; + } + listener_buffer->resetCapacity(capacity_size); + EXPECT_EQ(capacity_size, listener_buffer->capacity()); + max_bytes_read = capacity_size; + drained_size_ = 0; + available_data_.clear(); + EXPECT_EQ(listener_buffer->rawSlice().len_, 0); + break; + } + default: + break; + } + } + } + +private: + Network::MockIoHandle io_handle_; + Event::MockDispatcher dispatcher_; + Event::FileReadyCb file_event_callback_; + std::string available_data_; + // The size drained by the test. This is used to calculate the current buffer size. + uint64_t drained_size_{0}; +}; + +DEFINE_PROTO_FUZZER(const test::common::network::ListenerFilterBufferFuzzTestCase& input) { + auto fuzzer = ListenerFilterBufferFuzzer(); + fuzzer.fuzz(input); +} + +} // namespace +} // namespace Network +} // namespace Envoy diff --git a/test/common/network/listener_filter_buffer_impl_test.cc b/test/common/network/listener_filter_buffer_impl_test.cc new file mode 100644 index 0000000000000..63522edbe37c6 --- /dev/null +++ b/test/common/network/listener_filter_buffer_impl_test.cc @@ -0,0 +1,253 @@ +#include "envoy/api/io_error.h" + +#include "source/common/network/listener_filter_buffer_impl.h" + +#include "test/mocks/event/mocks.h" +#include "test/mocks/network/io_handle.h" + +#include "gtest/gtest.h" + +using testing::_; +using testing::ByMove; +using testing::Return; +using testing::SaveArg; + +namespace Envoy { +namespace Network { +namespace { + +class ListenerFilterBufferImplTest : public testing::Test { +public: + void initialize() { + EXPECT_CALL(io_handle_, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, + Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback_)); + + listener_buffer_ = std::make_unique( + io_handle_, dispatcher_, + [&](bool error) { + if (on_close_cb_) { + on_close_cb_(error); + } + }, + [&](ListenerFilterBufferImpl& filter_buffer) { + if (on_data_cb_) { + on_data_cb_(filter_buffer); + } + }, + buffer_size_); + } + std::unique_ptr listener_buffer_; + Network::MockIoHandle io_handle_; + Event::MockDispatcher dispatcher_; + uint64_t buffer_size_{512}; + ListenerFilterBufferOnDataCb on_data_cb_; + ListenerFilterBufferOnCloseCb on_close_cb_; + Event::FileReadyCb file_event_callback_; +}; + +TEST_F(ListenerFilterBufferImplTest, Basic) { + initialize(); + + // Peek 256 bytes data. + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + EXPECT_EQ(buffer_size_, length); + char* buf = static_cast(buffer); + for (size_t i = 0; i < length / 2; i++) { + buf[i] = 'a'; + } + return Api::IoCallUint64Result(length / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + on_data_cb_ = [&](ListenerFilterBuffer& filter_buffer) { + auto raw_buffer = filter_buffer.rawSlice(); + EXPECT_EQ(buffer_size_ / 2, raw_buffer.len_); + const char* buf = static_cast(raw_buffer.mem_); + for (uint64_t i = 0; i < raw_buffer.len_; i++) { + EXPECT_EQ(buf[i], 'a'); + } + }; + file_event_callback_(Event::FileReadyType::Read); + + // Peek another 256 bytes data. + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + EXPECT_EQ(buffer_size_, length); + char* buf = static_cast(buffer); + for (size_t i = length / 2; i < length; i++) { + buf[i] = 'b'; + } + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + + on_data_cb_ = [&](ListenerFilterBuffer& filter_buffer) { + auto raw_buffer = filter_buffer.rawSlice(); + EXPECT_EQ(buffer_size_, raw_buffer.len_); + const char* buf = static_cast(raw_buffer.mem_); + for (uint64_t i = 0; i < buffer_size_ / 2; i++) { + EXPECT_EQ(buf[i], 'a'); + } + for (uint64_t i = buffer_size_ / 2; i < buffer_size_; i++) { + EXPECT_EQ(buf[i], 'b'); + } + }; + file_event_callback_(Event::FileReadyType::Read); + + // On socket failure + bool is_closed = false; + on_close_cb_ = [&](bool) { is_closed = true; }; + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return( + ByMove(Api::IoCallUint64Result(-1, Api::IoErrorPtr(new IoSocketError(SOCKET_ERROR_INTR), + IoSocketError::deleteIoError))))); + file_event_callback_(Event::FileReadyType::Read); + EXPECT_TRUE(is_closed); + + // On remote closed + is_closed = false; + on_close_cb_ = [&](bool) { is_closed = true; }; + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return( + ByMove(Api::IoCallUint64Result(0, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + file_event_callback_(Event::FileReadyType::Read); + EXPECT_TRUE(is_closed); + + // On socket again. + is_closed = false; + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove( + Api::IoCallUint64Result(0, Api::IoErrorPtr(IoSocketError::getIoSocketEagainInstance(), + IoSocketError::deleteIoError))))); + file_event_callback_(Event::FileReadyType::Read); + EXPECT_FALSE(is_closed); +} + +TEST_F(ListenerFilterBufferImplTest, DrainData) { + initialize(); + + // Peek 256 bytes data. + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + EXPECT_EQ(buffer_size_, length); + char* buf = static_cast(buffer); + for (size_t i = 0; i < length / 2; i++) { + buf[i] = 'a'; + } + return Api::IoCallUint64Result(length / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + on_data_cb_ = [&](ListenerFilterBuffer& filter_buffer) { + auto raw_buffer = filter_buffer.rawSlice(); + EXPECT_EQ(buffer_size_ / 2, raw_buffer.len_); + const char* buf = static_cast(raw_buffer.mem_); + for (uint64_t i = 0; i < raw_buffer.len_; i++) { + EXPECT_EQ(buf[i], 'a'); + } + }; + file_event_callback_(Event::FileReadyType::Read); + + // Drain the 128 bytes data + uint64_t drained_size = 128; + + // Drain the data from the actual socket + EXPECT_CALL(io_handle_, recv) + .WillOnce([&](void*, size_t length, int flags) { + // expect to read, not peek + EXPECT_EQ(0, flags); + // expect to read the `drained_size` data + EXPECT_EQ(drained_size, length); + // only drain half data from the socket. + return Api::IoCallUint64Result(drained_size / 2, + Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }) + .WillOnce([&](void*, size_t length, int flags) { + // expect to read, not peek + EXPECT_EQ(0, flags); + // expect to read the `drained_size` data + EXPECT_EQ(drained_size / 2, length); + return Api::IoCallUint64Result(drained_size / 2, + Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + + listener_buffer_->drain(drained_size); + // Then should only can access the last 128 bytes + auto slice1 = listener_buffer_->rawSlice(); + EXPECT_EQ(drained_size, slice1.len_); + const char* buf = static_cast(slice1.mem_); + for (uint64_t i = 0; i < drained_size; i++) { + EXPECT_EQ(buf[i], 'a'); + } + + // Peek again + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + EXPECT_EQ(buffer_size_, length); + char* buf = static_cast(buffer); + for (uint64_t i = 0; i < length; i++) { + buf[i] = 'b'; + } + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + on_data_cb_ = [&](ListenerFilterBuffer& filter_buffer) { + auto raw_slice = filter_buffer.rawSlice(); + EXPECT_EQ(buffer_size_, raw_slice.len_); + buf = static_cast(raw_slice.mem_); + for (uint64_t i = 0; i < raw_slice.len_; i++) { + EXPECT_EQ(buf[i], 'b'); + } + }; + file_event_callback_(Event::FileReadyType::Read); +} + +TEST_F(ListenerFilterBufferImplTest, ResetCapacity) { + initialize(); + + // Peek 256 bytes data. + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + EXPECT_EQ(buffer_size_, length); + char* buf = static_cast(buffer); + for (size_t i = 0; i < length / 2; i++) { + buf[i] = 'a'; + } + return Api::IoCallUint64Result(length / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + on_data_cb_ = [&](ListenerFilterBuffer& filter_buffer) { + auto raw_buffer = filter_buffer.rawSlice(); + EXPECT_EQ(buffer_size_ / 2, raw_buffer.len_); + const char* buf = static_cast(raw_buffer.mem_); + for (uint64_t i = 0; i < raw_buffer.len_; i++) { + EXPECT_EQ(buf[i], 'a'); + } + }; + file_event_callback_(Event::FileReadyType::Read); + + listener_buffer_->resetCapacity(1024); + + EXPECT_EQ(1024, listener_buffer_->capacity()); + EXPECT_EQ(0, listener_buffer_->rawSlice().len_); + + // Peek 1024 bytes data. + EXPECT_CALL(io_handle_, recv).WillOnce([&](void* buffer, size_t length, int flags) { + EXPECT_EQ(MSG_PEEK, flags); + EXPECT_EQ(1024, length); + char* buf = static_cast(buffer); + for (size_t i = 0; i < length; i++) { + buf[i] = 'b'; + } + return Api::IoCallUint64Result(length, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + + on_data_cb_ = [&](ListenerFilterBuffer& filter_buffer) { + auto raw_buffer = filter_buffer.rawSlice(); + EXPECT_EQ(1024, raw_buffer.len_); + const char* buf = static_cast(raw_buffer.mem_); + for (uint64_t i = 0; i < raw_buffer.len_; i++) { + EXPECT_EQ(buf[i], 'b'); + } + }; + file_event_callback_(Event::FileReadyType::Read); +} + +} // namespace +} // namespace Network +} // namespace Envoy diff --git a/test/extensions/filters/listener/common/fuzz/BUILD b/test/extensions/filters/listener/common/fuzz/BUILD index 760575dbe87fa..fdf01989250f7 100644 --- a/test/extensions/filters/listener/common/fuzz/BUILD +++ b/test/extensions/filters/listener/common/fuzz/BUILD @@ -1,6 +1,5 @@ load( "//bazel:envoy_build_system.bzl", - "envoy_cc_test", "envoy_cc_test_library", "envoy_package", "envoy_proto_library", @@ -23,7 +22,10 @@ envoy_cc_test_library( ":listener_filter_fakes", ":listener_filter_fuzzer_proto_cc_proto", "//envoy/network:filter_interface", + "//source/common/network:connection_balancer_lib", + "//source/server:connection_handler_lib", "//test/mocks/network:network_mocks", + "//test/test_common:network_utility_lib", "//test/test_common:threadsafe_singleton_injector_lib", ], ) @@ -37,11 +39,3 @@ envoy_cc_test_library( "//test/mocks/network:network_mocks", ], ) - -envoy_cc_test( - name = "fuzzed_input_test", - srcs = ["fuzzed_input_test.cc"], - deps = [ - ":listener_filter_fuzzer_lib", - ], -) diff --git a/test/extensions/filters/listener/common/fuzz/fuzzed_input_test.cc b/test/extensions/filters/listener/common/fuzz/fuzzed_input_test.cc deleted file mode 100644 index 856980e1fc638..0000000000000 --- a/test/extensions/filters/listener/common/fuzz/fuzzed_input_test.cc +++ /dev/null @@ -1,79 +0,0 @@ -#include - -#include "test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.h" - -#include "gtest/gtest.h" - -namespace Envoy { -namespace Extensions { -namespace ListenerFilters { - -TEST(FuzzedInputStream, Empty) { - std::vector buffer; - std::vector indices; - FuzzedInputStream data(buffer, indices); - EXPECT_TRUE(data.empty()); - EXPECT_TRUE(data.done()); -} - -TEST(FuzzedInputStream, OneRead) { - std::vector buffer{'h', 'e', 'l', 'l', 'o'}; - std::vector indices{4}; - FuzzedInputStream data(buffer, indices); - EXPECT_FALSE(data.empty()); - EXPECT_EQ(data.size(), 5); - EXPECT_TRUE(data.done()); - - std::array read_data; - - // Test peeking - EXPECT_EQ(data.read(read_data.data(), 5, true).return_value_, 5); - EXPECT_EQ(data.size(), 5); - - // Test length > data.size() - EXPECT_EQ(data.read(read_data.data(), 10, true).return_value_, 5); - EXPECT_EQ(data.size(), 5); - - // Test non-peeking - EXPECT_EQ(data.read(read_data.data(), 3, false).return_value_, 3); - EXPECT_EQ(data.size(), 2); - - // Test reaching end-of-stream - EXPECT_EQ(data.read(read_data.data(), 5, false).return_value_, 2); - EXPECT_EQ(data.size(), 0); -} - -TEST(FuzzedInputStream, MultipleReads) { - std::vector buffer{'h', 'e', 'l', 'l', 'o'}; - std::vector indices{1, 3, 4}; - FuzzedInputStream data(buffer, indices); - EXPECT_FALSE(data.empty()); - EXPECT_EQ(data.size(), 2); - EXPECT_FALSE(data.done()); - - std::array read_data; - - // Test peeking (first read) - EXPECT_EQ(data.read(read_data.data(), 5, true).return_value_, 2); - EXPECT_EQ(data.size(), 2); - - data.next(); - EXPECT_FALSE(data.done()); - EXPECT_EQ(data.size(), 4); - - // Test non-peeking (second read) - EXPECT_EQ(data.read(read_data.data(), 3, false).return_value_, 3); - EXPECT_EQ(data.size(), 1); - - data.next(); - EXPECT_TRUE(data.done()); - EXPECT_EQ(data.size(), 2); - - // Test non-peeking (third read) and reaching end-of-stream - EXPECT_EQ(data.read(read_data.data(), 5, false).return_value_, 2); - EXPECT_EQ(data.size(), 0); -} - -} // namespace ListenerFilters -} // namespace Extensions -} // namespace Envoy diff --git a/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.cc b/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.cc index 280cfb4e5529b..88241d66d935d 100644 --- a/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.cc +++ b/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.cc @@ -1,5 +1,7 @@ #include "test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.h" +using testing::Return; + namespace Envoy { namespace Extensions { namespace ListenerFilters { @@ -22,107 +24,61 @@ void ListenerFilterFuzzer::fuzz( Network::Utility::resolveUrl("tcp://0.0.0.0:0")); } - FuzzedInputStream data(input); - - if (!data.empty()) { - ON_CALL(os_sys_calls_, recv(kFakeSocketFd, _, _, _)) - .WillByDefault(testing::Return(Api::SysCallSizeResult{static_cast(0), 0})); - - ON_CALL(dispatcher_, createFileEvent_(_, _, _, _)) - .WillByDefault(testing::DoAll(testing::SaveArg<1>(&file_event_callback_), - testing::SaveArg<3>(&events_), - testing::ReturnNew>())); - } - filter->onAccept(cb_); - - if (file_event_callback_ == nullptr) { - // If filter does not call createFileEvent (i.e. original_dst and original_src) - return; - } - - if (!data.empty()) { - ON_CALL(os_sys_calls_, ioctl(kFakeSocketFd, FIONREAD, _, _, _, _, _)) - .WillByDefault(Invoke([&data](os_fd_t, unsigned long, void* argp, unsigned long, void*, - unsigned long, unsigned long*) -> Api::SysCallIntResult { - int bytes_avail = static_cast(data.size()); - memcpy(argp, &bytes_avail, sizeof(int)); - return Api::SysCallIntResult{bytes_avail, 0}; - })); - { - testing::InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(kFakeSocketFd, _, _, _)) - .Times(testing::AnyNumber()) - .WillRepeatedly(Invoke( - [&data](os_fd_t, void* buffer, size_t length, int flags) -> Api::SysCallSizeResult { - return data.read(buffer, length, flags == MSG_PEEK); - })); - } - - bool got_continue = false; - - ON_CALL(cb_, continueFilterChain(true)) - .WillByDefault(testing::InvokeWithoutArgs([&got_continue]() { got_continue = true; })); - - while (!got_continue) { - if (data.done()) { // End of stream reached but not done - if (events_ & Event::FileReadyType::Closed) { - file_event_callback_(Event::FileReadyType::Closed); - } - return; - } else { - file_event_callback_(Event::FileReadyType::Read); - } - - data.next(); - } - } } -FuzzedInputStream::FuzzedInputStream( - const test::extensions::filters::listener::FilterFuzzTestCase& input) - : nreads_(input.data_size()) { - size_t len = 0; - for (int i = 0; i < nreads_; i++) { - len += input.data(i).size(); - } - - data_.reserve(len); +ListenerFilterWithDataFuzzer::ListenerFilterWithDataFuzzer() + : api_(Api::createApiForTest(stats_store_)), + dispatcher_(api_->allocateDispatcher("test_thread")), + socket_(std::make_shared( + Network::Test::getCanonicalLoopbackAddress(Network::Address::IpVersion::v4))), + connection_handler_(new Server::ConnectionHandlerImpl(*dispatcher_, absl::nullopt)), + name_("proxy"), filter_chain_(Network::Test::createEmptyFilterChainWithRawBufferSockets()), + init_manager_(nullptr) { + EXPECT_CALL(socket_factory_, socketType()).WillOnce(Return(Network::Socket::Type::Stream)); + EXPECT_CALL(socket_factory_, localAddress()) + .WillRepeatedly(ReturnRef(socket_->connectionInfoProvider().localAddress())); + EXPECT_CALL(socket_factory_, getListenSocket(_)).WillOnce(Return(socket_)); + connection_handler_->addListener(absl::nullopt, *this, runtime_); + conn_ = dispatcher_->createClientConnection(socket_->connectionInfoProvider().localAddress(), + Network::Address::InstanceConstSharedPtr(), + Network::Test::createRawBufferSocket(), nullptr); + conn_->addConnectionCallbacks(connection_callbacks_); +} - for (int i = 0; i < nreads_; i++) { - data_.insert(data_.end(), input.data(i).begin(), input.data(i).end()); - indices_.push_back(data_.size() - 1); - } +void ListenerFilterWithDataFuzzer::connect(Network::ListenerFilterPtr filter) { + EXPECT_CALL(factory_, createListenerFilterChain(_)) + .WillOnce(Invoke([&](Network::ListenerFilterManager& filter_manager) -> bool { + filter_manager.addAcceptFilter(nullptr, std::move(filter)); + dispatcher_->exit(); + return true; + })); + conn_->connect(); + + EXPECT_CALL(connection_callbacks_, onEvent(Network::ConnectionEvent::Connected)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); })); + dispatcher_->run(Event::Dispatcher::RunType::Block); } -FuzzedInputStream::FuzzedInputStream(std::vector buffer, std::vector indices) - : nreads_(indices.size()), data_(std::move(buffer)), indices_(std::move(indices)) {} +void ListenerFilterWithDataFuzzer::disconnect() { + EXPECT_CALL(connection_callbacks_, onEvent(Network::ConnectionEvent::LocalClose)) + .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); })); -void FuzzedInputStream::next() { - if (!done()) { - nread_++; - } + conn_->close(Network::ConnectionCloseType::NoFlush); + dispatcher_->run(Event::Dispatcher::RunType::Block); } -Api::SysCallSizeResult FuzzedInputStream::read(void* buffer, size_t length, bool peek) { - const size_t len = std::min(size(), length); // Number of bytes to write - memcpy(buffer, data_.data() + index_, len); - - if (!peek) { - // If not peeking, written bytes will be marked as read - index_ += len; +void ListenerFilterWithDataFuzzer::fuzz( + Network::ListenerFilterPtr filter, + const test::extensions::filters::listener::FilterFuzzWithDataTestCase& input) { + connect(std::move(filter)); + for (int i = 0; i < input.data_size(); i++) { + std::string data(input.data(i).begin(), input.data(i).end()); + write(data); } - - return Api::SysCallSizeResult{static_cast(len), 0}; + disconnect(); } -size_t FuzzedInputStream::size() const { return indices_[nread_] - index_ + 1; } - -bool FuzzedInputStream::done() { return nread_ >= nreads_ - 1; } - -bool FuzzedInputStream::empty() { return nreads_ == 0 || data_.empty(); } - } // namespace ListenerFilters } // namespace Extensions } // namespace Envoy diff --git a/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.h b/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.h index 7caab46d77883..9a52da99823c0 100644 --- a/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.h +++ b/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.h @@ -2,13 +2,20 @@ #include "envoy/network/filter.h" +#include "source/common/network/connection_balancer_impl.h" +#include "source/server/connection_handler_impl.h" + #include "test/extensions/filters/listener/common/fuzz/listener_filter_fakes.h" #include "test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.pb.validate.h" #include "test/mocks/event/mocks.h" #include "test/mocks/network/mocks.h" +#include "test/test_common/network_utility.h" #include "test/test_common/threadsafe_singleton_injector.h" #include "gmock/gmock.h" +#include "gtest/gtest.h" + +using testing::ReturnRef; namespace Envoy { namespace Extensions { @@ -27,43 +34,80 @@ class ListenerFilterFuzzer { const test::extensions::filters::listener::FilterFuzzTestCase& input); private: - FakeOsSysCalls os_sys_calls_; - TestThreadsafeSingletonInjector os_calls_{&os_sys_calls_}; NiceMock cb_; FakeConnectionSocket socket_; NiceMock dispatcher_; - Event::FileReadyCb file_event_callback_; - uint32_t events_; envoy::config::core::v3::Metadata metadata_; }; -class FuzzedInputStream { +class ListenerFilterWithDataFuzzer : public Network::ListenerConfig, + public Network::FilterChainManager { public: - FuzzedInputStream(const test::extensions::filters::listener::FilterFuzzTestCase& input); - - FuzzedInputStream(std::vector buffer, std::vector indices); - - // Makes data from the next read available to read() - void next(); + ListenerFilterWithDataFuzzer(); + + // Network::ListenerConfig + Network::FilterChainManager& filterChainManager() override { return *this; } + Network::FilterChainFactory& filterChainFactory() override { return factory_; } + Network::ListenSocketFactory& listenSocketFactory() override { return socket_factory_; } + bool bindToPort() override { return true; } + bool handOffRestoredDestinationConnections() const override { return false; } + uint32_t perConnectionBufferLimitBytes() const override { return 0; } + std::chrono::milliseconds listenerFiltersTimeout() const override { return {}; } + bool continueOnListenerFiltersTimeout() const override { return false; } + Stats::Scope& listenerScope() override { return stats_store_; } + uint64_t listenerTag() const override { return 1; } + ResourceLimit& openConnections() override { return open_connections_; } + const std::string& name() const override { return name_; } + Network::UdpListenerConfigOptRef udpListenerConfig() override { + return Network::UdpListenerConfigOptRef(); + } + Network::InternalListenerConfigOptRef internalListenerConfig() override { + return Network::InternalListenerConfigOptRef(); + } + envoy::config::core::v3::TrafficDirection direction() const override { + return envoy::config::core::v3::UNSPECIFIED; + } + Network::ConnectionBalancer& connectionBalancer() override { return connection_balancer_; } + const std::vector& accessLogs() const override { + return empty_access_logs_; + } + uint32_t tcpBacklogSize() const override { return ENVOY_TCP_BACKLOG_SIZE; } + Init::Manager& initManager() override { return *init_manager_; } + bool ignoreGlobalConnLimit() const override { return false; } - // Copies data into buffer and returns the number of bytes written - Api::SysCallSizeResult read(void* buffer, size_t length, bool peek); + // Network::FilterChainManager + const Network::FilterChain* findFilterChain(const Network::ConnectionSocket&) const override { + return filter_chain_.get(); + } - // Returns the number of bytes currently available to read() - size_t size() const; + void write(const std::string& s) { + Buffer::OwnedImpl buf(s); + conn_->write(buf, false); + } - // Returns true if end of stream reached (no more reads) - bool done(); + void connect(Network::ListenerFilterPtr filter); + void disconnect(); - // Returns true if data field in proto is empty - bool empty(); + void fuzz(Network::ListenerFilterPtr filter, + const test::extensions::filters::listener::FilterFuzzWithDataTestCase& input); private: - const int nreads_; // Number of reads - int nread_ = 0; // Counter of current read - size_t index_ = 0; // Index of first unread byte - std::vector data_; - std::vector indices_; // Ending indices for each read + testing::NiceMock runtime_; + Stats::TestUtil::TestStore stats_store_; + Api::ApiPtr api_; + BasicResourceLimitImpl open_connections_; + Event::DispatcherPtr dispatcher_; + std::shared_ptr socket_; + Network::MockListenSocketFactory socket_factory_; + Network::NopConnectionBalancerImpl connection_balancer_; + Network::ConnectionHandlerPtr connection_handler_; + Network::MockFilterChainFactory factory_; + Network::ClientConnectionPtr conn_; + NiceMock connection_callbacks_; + std::string name_; + const Network::FilterChainSharedPtr filter_chain_; + const std::vector empty_access_logs_; + std::unique_ptr init_manager_; }; } // namespace ListenerFilters diff --git a/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.proto b/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.proto index 11c0fa54250b7..1109b99545cc8 100644 --- a/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.proto +++ b/test/extensions/filters/listener/common/fuzz/listener_filter_fuzzer.proto @@ -9,5 +9,8 @@ message Socket { message FilterFuzzTestCase { Socket sock = 1; +} + +message FilterFuzzWithDataTestCase { repeated bytes data = 2; } diff --git a/test/extensions/filters/listener/http_inspector/BUILD b/test/extensions/filters/listener/http_inspector/BUILD index 240d7dab7745a..400b1658a6714 100644 --- a/test/extensions/filters/listener/http_inspector/BUILD +++ b/test/extensions/filters/listener/http_inspector/BUILD @@ -23,6 +23,8 @@ envoy_extension_cc_test( deps = [ "//source/common/common:hex_lib", "//source/common/http:utility_lib", + "//source/common/network:default_socket_interface_lib", + "//source/common/network:listener_filter_buffer_lib", "//source/extensions/filters/listener/http_inspector:http_inspector_lib", "//test/mocks/api:api_mocks", "//test/mocks/network:network_mocks", diff --git a/test/extensions/filters/listener/http_inspector/http_inspector_fuzz_test.cc b/test/extensions/filters/listener/http_inspector/http_inspector_fuzz_test.cc index d8ea3d2b2b4d8..b1d1e85ffecd0 100644 --- a/test/extensions/filters/listener/http_inspector/http_inspector_fuzz_test.cc +++ b/test/extensions/filters/listener/http_inspector/http_inspector_fuzz_test.cc @@ -9,7 +9,7 @@ namespace Extensions { namespace ListenerFilters { namespace HttpInspector { -DEFINE_PROTO_FUZZER(const test::extensions::filters::listener::FilterFuzzTestCase& input) { +DEFINE_PROTO_FUZZER(const test::extensions::filters::listener::FilterFuzzWithDataTestCase& input) { try { TestUtility::validate(input); } catch (const ProtoValidationException& e) { @@ -21,7 +21,7 @@ DEFINE_PROTO_FUZZER(const test::extensions::filters::listener::FilterFuzzTestCas ConfigSharedPtr cfg = std::make_shared(store); auto filter = std::make_unique(cfg); - ListenerFilterFuzzer fuzzer; + ListenerFilterWithDataFuzzer fuzzer; fuzzer.fuzz(std::move(filter), input); } diff --git a/test/extensions/filters/listener/http_inspector/http_inspector_test.cc b/test/extensions/filters/listener/http_inspector/http_inspector_test.cc index f5e68eae1d9a2..0d17553a41e91 100644 --- a/test/extensions/filters/listener/http_inspector/http_inspector_test.cc +++ b/test/extensions/filters/listener/http_inspector/http_inspector_test.cc @@ -1,6 +1,7 @@ #include "source/common/common/hex.h" #include "source/common/http/utility.h" #include "source/common/network/io_socket_handle_impl.h" +#include "source/common/network/listener_filter_buffer_impl.h" #include "source/extensions/filters/listener/http_inspector/http_inspector.h" #include "test/mocks/api/mocks.h" @@ -30,10 +31,11 @@ class HttpInspectorTest : public testing::Test { public: HttpInspectorTest() : cfg_(std::make_shared(store_)), - io_handle_(std::make_unique(42)) {} + io_handle_( + Network::SocketInterfaceImpl::makePlatformSpecificSocket(42, false, absl::nullopt)) {} ~HttpInspectorTest() override { io_handle_->close(); } - void init(bool include_inline_recv = true) { + void init() { filter_ = std::make_unique(cfg_); EXPECT_CALL(cb_, socket()).WillRepeatedly(ReturnRef(socket_)); @@ -41,20 +43,287 @@ class HttpInspectorTest : public testing::Test { EXPECT_CALL(cb_, dispatcher()).WillRepeatedly(ReturnRef(dispatcher_)); EXPECT_CALL(testing::Const(socket_), ioHandle()).WillRepeatedly(ReturnRef(*io_handle_)); EXPECT_CALL(socket_, ioHandle()).WillRepeatedly(ReturnRef(*io_handle_)); + EXPECT_CALL(dispatcher_, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, + Event::FileReadyType::Read)) + .WillOnce( + DoAll(SaveArg<1>(&file_event_callback_), ReturnNew>())); + buffer_ = std::make_unique( + *io_handle_, dispatcher_, [](bool) {}, [](Network::ListenerFilterBuffer&) {}, + filter_->maxReadBytes()); + } - if (include_inline_recv) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + void testHttpInspectMultipleReadsSuccess(absl::string_view header, bool http2 = false) { + init(); + const std::vector data = Hex::decode(std::string(header)); + { + InSequence s; + +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + if (http2) { + for (size_t i = 0; i < data.size(); i++) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [&data, i](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= data.size()); + memcpy(iov->iov_base, data.data() + i, 1); + return Api::SysCallSizeResult{ssize_t(1), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } + } else { + for (size_t i = 0; i < header.size(); i++) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [&header, i](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= header.size()); + memcpy(iov->iov_base, header.data() + i, 1); + return Api::SysCallSizeResult{ssize_t(1), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } + } +#else + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { + return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; + })); + if (http2) { + for (size_t i = 1; i <= data.size(); i++) { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke( + [&data, i](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= i); + memcpy(buffer, data.data(), i); + return Api::SysCallSizeResult{ssize_t(i), 0}; + })); + } + } else { + for (size_t i = 1; i <= header.size(); i++) { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke([&header, i](os_fd_t, void* buffer, size_t length, + int) -> Api::SysCallSizeResult { + ASSERT(length >= i); + memcpy(buffer, header.data(), i); + return Api::SysCallSizeResult{ssize_t(i), 0}; + })); + } + } +#endif + } + bool got_continue = false; + EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); + EXPECT_CALL(socket_, close()).WillOnce(InvokeWithoutArgs([&got_continue]() { + got_continue = true; + })); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); + while (!got_continue) { + file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + EXPECT_EQ(status, Network::FilterStatus::StopIteration); + } - EXPECT_CALL(dispatcher_, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, - Event::FileReadyType::Read)) - .WillOnce(DoAll(SaveArg<1>(&file_event_callback_), - ReturnNew>())); + EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + } + + void testHttpInspectMultipleReadsSuccess(absl::string_view header, absl::string_view alpn) { + init(); + const std::vector alpn_protos{alpn}; + const std::vector data = Hex::decode(std::string(header)); + { + InSequence s; + +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + if (alpn == Http::Utility::AlpnNames::get().Http2c) { + for (size_t i = 0; i < 24; i++) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [&data, i](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= data.size()); + memcpy(iov->iov_base, data.data() + i, 1); + return Api::SysCallSizeResult{ssize_t(1), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } + } else { + for (size_t i = 0; i < header.size(); i++) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [&header, i](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= header.size()); + memcpy(iov->iov_base, header.data() + i, 1); + return Api::SysCallSizeResult{ssize_t(1), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } + } +#else + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { + return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; + })); - filter_->onAccept(cb_); + if (alpn == Http::Utility::AlpnNames::get().Http2c) { + for (size_t i = 1; i <= 24; i++) { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke( + [&data, i](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= i); + memcpy(buffer, data.data(), i); + return Api::SysCallSizeResult{ssize_t(i), 0}; + })); + } + } else { + for (size_t i = 1; i <= header.size(); i++) { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke([&header, i](os_fd_t, void* buffer, size_t length, + int) -> Api::SysCallSizeResult { + ASSERT(length >= i); + memcpy(buffer, header.data(), i); + return Api::SysCallSizeResult{ssize_t(i), 0}; + })); + } + } +#endif + } + + bool got_continue = false; + EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); + while (!got_continue) { + file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + if (status == Network::FilterStatus::Continue) { + got_continue = true; + } + } + if (alpn == Http::Utility::AlpnNames::get().Http11) { + EXPECT_EQ(1, cfg_->stats().http11_found_.value()); + } else if (alpn == Http::Utility::AlpnNames::get().Http10) { + EXPECT_EQ(1, cfg_->stats().http10_found_.value()); + } else if (alpn == Http::Utility::AlpnNames::get().Http2c) { + EXPECT_EQ(1, cfg_->stats().http2_found_.value()); + } else { + EXPECT_EQ(alpn, "unknow alpn"); + } + } + + void testHttpInspectSuccess(absl::string_view header, absl::string_view alpn) { + init(); + std::vector data = Hex::decode(std::string(header)); +#ifdef WIN32 + if (alpn == Http::Utility::AlpnNames::get().Http2c) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce( + Invoke([&data](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= data.size()); + memcpy(iov->iov_base, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } else { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce( + Invoke([&header](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= header.size()); + memcpy(iov->iov_base, header.data(), header.size()); + return Api::SysCallSizeResult{ssize_t(header.size()), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } +#else + if (alpn == Http::Utility::AlpnNames::get().Http2c) { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce( + Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= data.size()); + memcpy(buffer, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })); + } else { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke( + [&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= header.size()); + memcpy(buffer, header.data(), header.size()); + return Api::SysCallSizeResult{ssize_t(header.size()), 0}; + })); + } +#endif + const std::vector alpn_protos{alpn}; + + EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); + file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + EXPECT_EQ(status, Network::FilterStatus::Continue); + if (alpn == Http::Utility::AlpnNames::get().Http11) { + EXPECT_EQ(1, cfg_->stats().http11_found_.value()); + } else if (alpn == Http::Utility::AlpnNames::get().Http10) { + EXPECT_EQ(1, cfg_->stats().http10_found_.value()); + } else if (alpn == Http::Utility::AlpnNames::get().Http2c) { + EXPECT_EQ(1, cfg_->stats().http2_found_.value()); + } else { + EXPECT_EQ(alpn, "unknow alpn"); } } + void testHttpInspectFail(absl::string_view header, bool http2 = false) { + init(); + std::vector data = Hex::decode(std::string(header)); +#ifdef WIN32 + if (http2) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce( + Invoke([&data](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= data.size()); + memcpy(iov->iov_base, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } else { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce( + Invoke([&header](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= header.size()); + memcpy(iov->iov_base, header.data(), header.size()); + return Api::SysCallSizeResult{ssize_t(header.size()), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } +#else + if (http2) { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce( + Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= data.size()); + memcpy(buffer, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })); + } else { + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke( + [&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= header.size()); + memcpy(buffer, header.data(), header.size()); + return Api::SysCallSizeResult{ssize_t(header.size()), 0}; + })); + } +#endif + EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); + EXPECT_CALL(socket_, close()); + file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + EXPECT_EQ(status, Network::FilterStatus::StopIteration); + EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + } + NiceMock os_sys_calls_; TestThreadsafeSingletonInjector os_calls_{&os_sys_calls_}; Stats::IsolatedStoreImpl store_; @@ -65,6 +334,7 @@ class HttpInspectorTest : public testing::Test { NiceMock dispatcher_; Event::FileReadyCb file_event_callback_; Network::IoHandlePtr io_handle_; + std::unique_ptr buffer_; }; TEST_F(HttpInspectorTest, SkipHttpInspectForTLS) { @@ -76,276 +346,101 @@ TEST_F(HttpInspectorTest, SkipHttpInspectForTLS) { EXPECT_EQ(filter_->onAccept(cb_), Network::FilterStatus::Continue); } -TEST_F(HttpInspectorTest, InlineReadIoError) { - init(/*include_inline_recv=*/false); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke([](os_fd_t, void*, size_t, int) -> Api::SysCallSizeResult { - return Api::SysCallSizeResult{ssize_t(-1), 0}; - })); - EXPECT_CALL(dispatcher_, createFileEvent_(_, _, _, _)).Times(0); - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(socket_, close()); - auto accepted = filter_->onAccept(cb_); - EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); - // It's arguable if io error should bump the not_found counter - EXPECT_EQ(0, cfg_->stats().http_not_found_.value()); -} - TEST_F(HttpInspectorTest, InlineReadInspectHttp10) { - init(/*include_inline_recv=*/false); const absl::string_view header = "GET /anything HTTP/1.0\r\nhost: google.com\r\nuser-agent: curl/7.64.0\r\naccept: " "*/*\r\nx-forwarded-proto: http\r\nx-request-id: " "a52df4a0-ed00-4a19-86a7-80e5049c6c84\r\nx-envoy-expected-rq-timeout-ms: " "15000\r\ncontent-length: 0\r\n\r\n"; - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http10}; - - EXPECT_CALL(dispatcher_, createFileEvent_(_, _, _, _)).Times(0); - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - auto accepted = filter_->onAccept(cb_); - EXPECT_EQ(accepted, Network::FilterStatus::Continue); - EXPECT_EQ(1, cfg_->stats().http10_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http10); } TEST_F(HttpInspectorTest, InlineReadParseError) { - init(/*include_inline_recv=*/false); const absl::string_view header = "NOT_A_LEGAL_PREFIX /anything HTTP/1.0\r\nhost: google.com\r\nuser-agent: " "curl/7.64.0\r\naccept: " "*/*\r\nx-forwarded-proto: http\r\nx-request-id: " "a52df4a0-ed00-4a19-86a7-80e5049c6c84\r\nx-envoy-expected-rq-timeout-ms: " "15000\r\ncontent-length: 0\r\n\r\n"; - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - EXPECT_CALL(dispatcher_, createFileEvent_(_, _, _, _)).Times(0); - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - auto accepted = filter_->onAccept(cb_); - EXPECT_EQ(accepted, Network::FilterStatus::Continue); - EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + testHttpInspectFail(header); } TEST_F(HttpInspectorTest, InspectHttp10) { - init(true); const absl::string_view header = "GET /anything HTTP/1.0\r\nhost: google.com\r\nuser-agent: curl/7.64.0\r\naccept: " "*/*\r\nx-forwarded-proto: http\r\nx-request-id: " "a52df4a0-ed00-4a19-86a7-80e5049c6c84\r\nx-envoy-expected-rq-timeout-ms: " "15000\r\ncontent-length: 0\r\n\r\n"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http10}; - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http10_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http10); } TEST_F(HttpInspectorTest, InspectHttp11) { - init(); const absl::string_view header = "GET /anything HTTP/1.1\r\nhost: google.com\r\nuser-agent: curl/7.64.0\r\naccept: " "*/*\r\nx-forwarded-proto: http\r\nx-request-id: " "a52df4a0-ed00-4a19-86a7-80e5049c6c84\r\nx-envoy-expected-rq-timeout-ms: " "15000\r\ncontent-length: 0\r\n\r\n"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http11}; - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http11_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http11); } TEST_F(HttpInspectorTest, InspectHttp11WithNonEmptyRequestBody) { - init(); const absl::string_view header = "GET /anything HTTP/1.1\r\nhost: google.com\r\nuser-agent: curl/7.64.0\r\naccept: " "*/*\r\nx-forwarded-proto: http\r\nx-request-id: " "a52df4a0-ed00-4a19-86a7-80e5049c6c84\r\nx-envoy-expected-rq-timeout-ms: " "15000\r\ncontent-length: 3\r\n\r\nfoo"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http11}; - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http11_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http11); } TEST_F(HttpInspectorTest, ExtraSpaceInRequestLine) { - init(); const absl::string_view header = "GET /anything HTTP/1.1\r\n\r\n"; - // ^^ ^^ - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http11}; - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http11_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http11); } TEST_F(HttpInspectorTest, InvalidHttpMethod) { - init(); const absl::string_view header = "BAD /anything HTTP/1.1"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(0, cfg_->stats().http11_found_.value()); + testHttpInspectFail(header); } TEST_F(HttpInspectorTest, InvalidHttpRequestLine) { - init(); const absl::string_view header = "BAD /anything HTTP/1.1\r\n"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(_)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + testHttpInspectFail(header); } TEST_F(HttpInspectorTest, OldHttpProtocol) { - init(); const absl::string_view header = "GET /anything HTTP/0.9\r\n"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http10}; - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http10_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http10); } TEST_F(HttpInspectorTest, InvalidRequestLine) { - init(); const absl::string_view header = "GET /anything HTTP/1.1 BadRequestLine\r\n"; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&header](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= header.size()); - memcpy(buffer, header.data(), header.size()); - return Api::SysCallSizeResult{ssize_t(header.size()), 0}; - })); - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + testHttpInspectFail(header); } TEST_F(HttpInspectorTest, InspectHttp2) { - init(); - const std::string header = "505249202a20485454502f322e300d0a0d0a534d0d0a0d0a00000c04000000000000041000000000020000000000" "00040800000000000fff000100007d010500000001418aa0e41d139d09b8f0000f048860757a4ce6aa660582867a" "8825b650c3abb8d2e053032a2f2a408df2b4a7b3c0ec90b22d5d8749ff839d29af4089f2b585ed6950958d279a18" "9e03f1ca5582265f59a75b0ac3111959c7e49004908db6e83f4096f2b16aee7f4b17cd65224b22d6765926a4a7b5" "2b528f840b60003f"; - std::vector data = Hex::decode(header); - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= data.size()); - memcpy(buffer, data.data(), data.size()); - return Api::SysCallSizeResult{ssize_t(data.size()), 0}; - })); - - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http2c}; - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().http2_found_.value()); -} - -TEST_F(HttpInspectorTest, ReadClosed) { - init(); - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Return(Api::SysCallSizeResult{0, 0})); - EXPECT_CALL(cb_, continueFilterChain(false)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(0, cfg_->stats().http2_found_.value()); + testHttpInspectSuccess(header, Http::Utility::AlpnNames::get().Http2c); } TEST_F(HttpInspectorTest, InvalidConnectionPreface) { init(); const std::string header = "505249202a20485454502f322e300d0a"; - const std::vector data = Hex::decode(header); - + std::vector data = Hex::decode(std::string(header)); +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke([&data](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= data.size()); + memcpy(iov->iov_base, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); +#else EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) .WillOnce( Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { @@ -353,227 +448,46 @@ TEST_F(HttpInspectorTest, InvalidConnectionPreface) { memcpy(buffer, data.data(), data.size()); return Api::SysCallSizeResult{ssize_t(data.size()), 0}; })); - +#endif EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)).Times(0); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + EXPECT_EQ(status, Network::FilterStatus::StopIteration); EXPECT_EQ(0, cfg_->stats().http_not_found_.value()); } -TEST_F(HttpInspectorTest, ReadError) { - init(); - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_NOT_SUP}; - })); - EXPECT_CALL(cb_, continueFilterChain(false)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().read_error_.value()); -} - TEST_F(HttpInspectorTest, MultipleReadsHttp2) { - init(); - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http2c}; - const std::string header = "505249202a20485454502f322e300d0a0d0a534d0d0a0d0a00000c04000000000000041000000000020000000000" "00040800000000000fff000100007d010500000001418aa0e41d139d09b8f0000f048860757a4ce6aa660582867a" "8825b650c3abb8d2e053032a2f2a408df2b4a7b3c0ec90b22d5d8749ff839d29af4089f2b585ed6950958d279a18" "9e03f1ca5582265f59a75b0ac3111959c7e49004908db6e83f4096f2b16aee7f4b17cd65224b22d6765926a4a7b5" "2b528f840b60003f"; - const std::vector data = Hex::decode(header); - { - InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); - - for (size_t i = 1; i <= 24; i++) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&data, i](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= i); - memcpy(buffer, data.data(), i); - return Api::SysCallSizeResult{ssize_t(i), 0}; - })); - } - } - - bool got_continue = false; - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); - while (!got_continue) { - file_event_callback_(Event::FileReadyType::Read); - } - EXPECT_EQ(1, cfg_->stats().http2_found_.value()); + testHttpInspectMultipleReadsSuccess(header, Http::Utility::AlpnNames::get().Http2c); } TEST_F(HttpInspectorTest, MultipleReadsHttp2BadPreface) { - init(); const std::string header = "505249202a20485454502f322e300d0a0d0c"; - const std::vector data = Hex::decode(header); - { - InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); - - for (size_t i = 1; i <= data.size(); i++) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&data, i](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= i); - memcpy(buffer, data.data(), i); - return Api::SysCallSizeResult{ssize_t(i), 0}; - })); - } - } - - bool got_continue = false; - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); - while (!got_continue) { - file_event_callback_(Event::FileReadyType::Read); - } - EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + testHttpInspectMultipleReadsSuccess(header, true); } TEST_F(HttpInspectorTest, MultipleReadsHttp1) { - init(); const absl::string_view data = "GET /anything HTTP/1.0\r"; - { - InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); - - for (size_t i = 1; i <= data.size(); i++) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&data, i](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= i); - memcpy(buffer, data.data(), i); - return Api::SysCallSizeResult{ssize_t(i), 0}; - })); - } - } - - bool got_continue = false; - const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http10}; - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); - while (!got_continue) { - file_event_callback_(Event::FileReadyType::Read); - } - EXPECT_EQ(1, cfg_->stats().http10_found_.value()); -} - -TEST_F(HttpInspectorTest, MultipleReadsHttp1IncompleteHeader) { - init(); - const absl::string_view data = "GE"; - bool end_stream = false; - { - InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); - - for (size_t i = 1; i <= data.size(); i++) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke([&data, &end_stream, i](os_fd_t, void* buffer, size_t length, - int) -> Api::SysCallSizeResult { - ASSERT(length >= i); - memcpy(buffer, data.data(), i); - if (i == data.size()) { - end_stream = true; - } - - return Api::SysCallSizeResult{ssize_t(i), 0}; - })); - } - } - - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_EQ(0, cfg_->stats().http_not_found_.value()); - while (!end_stream) { - file_event_callback_(Event::FileReadyType::Read); - } + testHttpInspectMultipleReadsSuccess(data, Http::Utility::AlpnNames::get().Http10); } TEST_F(HttpInspectorTest, MultipleReadsHttp1IncompleteBadHeader) { - init(); const absl::string_view data = "X"; - { - InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); - - for (size_t i = 1; i <= data.size(); i++) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&data, i](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= i); - memcpy(buffer, data.data(), i); - return Api::SysCallSizeResult{ssize_t(i), 0}; - })); - } - } - - bool got_continue = false; - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); - while (!got_continue) { - file_event_callback_(Event::FileReadyType::Read); - } - EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + testHttpInspectMultipleReadsSuccess(data); } TEST_F(HttpInspectorTest, MultipleReadsHttp1BadProtocol) { - init(); const std::string valid_header = "GET /index HTTP/1.1\r"; // offset: 0 10 const std::string truncate_header = valid_header.substr(0, 14).append("\r"); - { - InSequence s; - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); - - for (size_t i = 1; i <= truncate_header.size(); i++) { - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke([&truncate_header, i](os_fd_t, void* buffer, size_t length, - int) -> Api::SysCallSizeResult { - ASSERT(length >= truncate_header.size()); - memcpy(buffer, truncate_header.data(), truncate_header.size()); - return Api::SysCallSizeResult{ssize_t(i), 0}; - })); - } - } - - bool got_continue = false; - EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); - while (!got_continue) { - file_event_callback_(Event::FileReadyType::Read); - } - EXPECT_EQ(1, cfg_->stats().http_not_found_.value()); + testHttpInspectMultipleReadsSuccess(truncate_header); } TEST_F(HttpInspectorTest, Http1WithLargeRequestLine) { @@ -586,10 +500,14 @@ TEST_F(HttpInspectorTest, Http1WithLargeRequestLine) { const std::string data = absl::StrCat(method, spaces, http); { InSequence s; - +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); +#else EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; })); +#endif uint64_t num_loops = Config::MAX_INSPECT_SIZE; #if defined(__has_feature) && \ @@ -597,6 +515,20 @@ TEST_F(HttpInspectorTest, Http1WithLargeRequestLine) { num_loops = 2; #endif +#ifdef WIN32 + auto ctr = std::make_shared(0); + auto copy_len = std::make_shared(1); + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .Times(num_loops) + .WillRepeatedly( + Invoke([&data, ctr, copy_len, num_loops](os_fd_t fd, const iovec* iov, + int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= 1); + memcpy(iov->iov_base, data.data() + *ctr, 1); + *ctr += 1; + return Api::SysCallSizeResult{ssize_t(1), 0}; + })); +#else auto ctr = std::make_shared(1); EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) .Times(num_loops) @@ -612,16 +544,20 @@ TEST_F(HttpInspectorTest, Http1WithLargeRequestLine) { *ctr += 1; return Api::SysCallSizeResult{ssize_t(len), 0}; })); +#endif } bool got_continue = false; const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http10}; EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); while (!got_continue) { file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + if (status == Network::FilterStatus::Continue) { + got_continue = true; + } } EXPECT_EQ(1, cfg_->stats().http10_found_.value()); } @@ -632,9 +568,23 @@ TEST_F(HttpInspectorTest, Http1WithLargeHeader) { // 0 20 std::string value(Config::MAX_INSPECT_SIZE - request.size(), 'a'); const std::string data = absl::StrCat(request, value); + { InSequence s; - +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + for (size_t i = 0; i < 20; i++) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [&data, i](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= 20); + memcpy(iov->iov_base, data.data() + i, 1); + return Api::SysCallSizeResult{ssize_t(1), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } +#else EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; })); @@ -648,16 +598,20 @@ TEST_F(HttpInspectorTest, Http1WithLargeHeader) { return Api::SysCallSizeResult{ssize_t(i), 0}; })); } +#endif } bool got_continue = false; const std::vector alpn_protos{Http::Utility::AlpnNames::get().Http10}; EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); + auto accepted = filter_->onAccept(cb_); + EXPECT_EQ(accepted, Network::FilterStatus::StopIteration); while (!got_continue) { file_event_callback_(Event::FileReadyType::Read); + auto status = filter_->onData(*buffer_); + if (status == Network::FilterStatus::Continue) { + got_continue = true; + } } EXPECT_EQ(1, cfg_->stats().http10_found_.value()); } diff --git a/test/extensions/filters/listener/original_src/original_src_test.cc b/test/extensions/filters/listener/original_src/original_src_test.cc index 45c4d2cb0d5fb..539c0f71bacbd 100644 --- a/test/extensions/filters/listener/original_src/original_src_test.cc +++ b/test/extensions/filters/listener/original_src/original_src_test.cc @@ -62,6 +62,7 @@ class OriginalSrcTest : public testing::Test { TEST_F(OriginalSrcTest, OnNewConnectionUnixSocketSkips) { auto filter = makeDefaultFilter(); + EXPECT_EQ(filter->maxReadBytes(), 0); setAddressToReturn("unix://domain.socket"); EXPECT_CALL(callbacks_.socket_, addOption_(_)).Times(0); EXPECT_EQ(filter->onAccept(callbacks_), Network::FilterStatus::Continue); diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.cc b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.cc index 7416e92c8bba4..ff844a447d81b 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.cc +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.cc @@ -22,7 +22,7 @@ DEFINE_PROTO_FUZZER( ConfigSharedPtr cfg = std::make_shared(store, input.config()); auto filter = std::make_unique(std::move(cfg)); - ListenerFilterFuzzer fuzzer; + ListenerFilterWithDataFuzzer fuzzer; fuzzer.fuzz(std::move(filter), input.fuzzed()); } diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.proto b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.proto index bab715e0e34f2..608de9f4ee7eb 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.proto +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_fuzz_test.proto @@ -9,6 +9,6 @@ import "validate/validate.proto"; message ProxyProtocolTestCase { envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol config = 1 [(validate.rules).message.required = true]; - test.extensions.filters.listener.FilterFuzzTestCase fuzzed = 2 + test.extensions.filters.listener.FilterFuzzWithDataTestCase fuzzed = 2 [(validate.rules).message.required = true]; } diff --git a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc index d63c6e7a8ca18..6efa2b9cc120c 100644 --- a/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc +++ b/test/extensions/filters/listener/proxy_protocol/proxy_protocol_test.cc @@ -179,12 +179,14 @@ class ProxyProtocolTest : public testing::TestWithParamrun(Event::Dispatcher::RunType::Block); } - void expectProxyProtoError() { + void expectConnectionError() { EXPECT_CALL(connection_callbacks_, onEvent(Network::ConnectionEvent::RemoteClose)) .WillOnce(Invoke([&](Network::ConnectionEvent) -> void { dispatcher_->exit(); })); dispatcher_->run(Event::Dispatcher::RunType::Block); - + } + void expectProxyProtoError() { + expectConnectionError(); EXPECT_EQ(stats_store_.counter("downstream_cx_proxy_proto_error").value(), 1); } @@ -381,8 +383,7 @@ TEST_P(ProxyProtocolTest, ErrorRecv_2) { connect(false); write(buffer, sizeof(buffer)); - errno = 0; - expectProxyProtoError(); + expectConnectionError(); } TEST_P(ProxyProtocolTest, ErrorRecv_1) { @@ -457,7 +458,7 @@ TEST_P(ProxyProtocolTest, ErrorRecv_1) { connect(false); write(buffer, sizeof(buffer)); - expectProxyProtoError(); + expectConnectionError(); } TEST_P(ProxyProtocolTest, V2NotLocalOrOnBehalf) { @@ -637,18 +638,34 @@ TEST_P(ProxyProtocolTest, V2ParseExtensionsRecvError) { Api::MockOsSysCalls os_sys_calls; TestThreadsafeSingletonInjector os_calls(&os_sys_calls); - + bool header_writed = false; // TODO(davinci26): Mocking should not be used to provide real system calls. +#ifdef WIN32 + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([&](os_fd_t fd, const iovec* iov, int iovcnt) { + const Api::SysCallSizeResult x = os_sys_calls_actual_.readv(fd, iov, iovcnt); + if (header_writed) { + return Api::SysCallSizeResult{-1, 0}; + } + return x; + })); +#else EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, void* buf, size_t n, int flags) { + .WillRepeatedly(Invoke([&](os_fd_t fd, void* buf, size_t n, int flags) { const Api::SysCallSizeResult x = os_sys_calls_actual_.recv(fd, buf, n, flags); - if (x.return_value_ == sizeof(tlv)) { + if (header_writed) { return Api::SysCallSizeResult{-1, 0}; - } else { - return x; } + return x; })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.readv(fd, iov, iovcnt); + })); +#endif EXPECT_CALL(os_sys_calls, connect(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t sockfd, const sockaddr* addr, socklen_t addrlen) { @@ -659,11 +676,6 @@ TEST_P(ProxyProtocolTest, V2ParseExtensionsRecvError) { .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { return os_sys_calls_actual_.writev(fd, iov, iovcnt); })); - EXPECT_CALL(os_sys_calls, readv(_, _, _)) - .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { - return os_sys_calls_actual_.readv(fd, iov, iovcnt); - })); EXPECT_CALL(os_sys_calls, getsockopt_(_, _, _, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke( @@ -702,9 +714,10 @@ TEST_P(ProxyProtocolTest, V2ParseExtensionsRecvError) { connect(false); write(buffer, sizeof(buffer)); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + header_writed = true; write(tlv, sizeof(tlv)); - expectProxyProtoError(); + expectConnectionError(); } TEST_P(ProxyProtocolTest, V2ParseExtensionsFrag) { @@ -826,25 +839,28 @@ TEST_P(ProxyProtocolTest, V2Fragmented4Error) { Api::MockOsSysCalls os_sys_calls; TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + bool partial_writed = false; // TODO(davinci26): Mocking should not be used to provide real system calls. #ifdef WIN32 EXPECT_CALL(os_sys_calls, readv(_, _, _)) .Times(AnyNumber()) - .WillOnce(Invoke([&](os_fd_t fd, const iovec* iov, int num_iov) { + .WillRepeatedly(Invoke([&](os_fd_t fd, const iovec* iov, int num_iov) { const Api::SysCallSizeResult x = os_sys_calls_actual_.readv(fd, iov, num_iov); + if (partial_writed) { + return Api::SysCallSizeResult{-1, 0}; + } return x; - })) - .WillRepeatedly(Return(Api::SysCallSizeResult{-1, 0})); + })); #else EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, void* buf, size_t len, int flags) { - return os_sys_calls_actual_.recv(fd, buf, len, flags); + .WillRepeatedly(Invoke([&](os_fd_t fd, void* buf, size_t n, int flags) { + const Api::SysCallSizeResult x = os_sys_calls_actual_.recv(fd, buf, n, flags); + if (partial_writed) { + return Api::SysCallSizeResult{-1, 0}; + } + return x; })); - EXPECT_CALL(os_sys_calls, recv(_, _, 1, _)) - .Times(AnyNumber()) - .WillOnce(Return(Api::SysCallSizeResult{-1, 0})); - EXPECT_CALL(os_sys_calls, readv(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { @@ -898,8 +914,11 @@ TEST_P(ProxyProtocolTest, V2Fragmented4Error) { })); connect(false); write(buffer, 17); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + partial_writed = true; + write(buffer, 11); - expectProxyProtoError(); + expectConnectionError(); } TEST_P(ProxyProtocolTest, V2Fragmented5Error) { @@ -913,9 +932,9 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { Api::MockOsSysCalls os_sys_calls; TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + bool partial_write = false; // TODO(davinci26): Mocking should not be used to provide real system calls. #ifdef WIN32 - bool partial_write = false; EXPECT_CALL(os_sys_calls, readv(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([&](os_fd_t fd, const iovec* iov, int num_iov) { @@ -928,12 +947,13 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { #else EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) .Times(AnyNumber()) - .WillRepeatedly(Invoke([this](os_fd_t fd, void* buf, size_t len, int flags) { - return os_sys_calls_actual_.recv(fd, buf, len, flags); + .WillRepeatedly(Invoke([&](os_fd_t fd, void* buf, size_t n, int flags) { + const Api::SysCallSizeResult x = os_sys_calls_actual_.recv(fd, buf, n, flags); + if (partial_write) { + return Api::SysCallSizeResult{-1, 0}; + } + return x; })); - EXPECT_CALL(os_sys_calls, recv(_, _, 4, _)) - .Times(AnyNumber()) - .WillOnce(Return(Api::SysCallSizeResult{-1, 0})); EXPECT_CALL(os_sys_calls, readv(_, _, _)) .Times(AnyNumber()) .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { @@ -988,12 +1008,10 @@ TEST_P(ProxyProtocolTest, V2Fragmented5Error) { connect(false); write(buffer, 10); dispatcher_->run(Event::Dispatcher::RunType::NonBlock); -#ifdef WIN32 partial_write = true; -#endif write(buffer + 10, 10); - expectProxyProtoError(); + expectConnectionError(); } TEST_P(ProxyProtocolTest, PartialRead) { @@ -1044,6 +1062,47 @@ TEST_P(ProxyProtocolTest, V2PartialRead) { const std::string ProxyProtocol = "envoy.filters.listener.proxy_protocol"; +TEST_P(ProxyProtocolTest, V2ParseExtensionsLargeThanInitMaxReadBytes) { + // A well-formed ipv4/tcp with a pair of TLV extensions is accepted + constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, + 0x54, 0x0a, 0x21, 0x11, 0xff, 0xff, 0x01, 0x02, 0x03, 0x04, + 0x00, 0x01, 0x01, 0x02, 0x03, 0x05, 0x00, 0x02}; + // The TLV has 65520 size data. + constexpr uint8_t tlv_begin[] = {0x02, 0xff, 0xf0}; + std::string tlv_data(65520, 'a'); + + constexpr uint8_t data[] = {'D', 'A', 'T', 'A'}; + + envoy::extensions::filters::listener::proxy_protocol::v3::ProxyProtocol proto_config; + auto rule = proto_config.add_rules(); + rule->set_tlv_type(0x02); + rule->mutable_on_tlv_present()->set_key("PP2 type authority"); + + connect(true, &proto_config); + write(buffer, sizeof(buffer)); + dispatcher_->run(Event::Dispatcher::RunType::NonBlock); + + write(tlv_begin, sizeof(tlv_begin)); + write(tlv_data); + + write(data, sizeof(data)); + expectData("DATA"); + + EXPECT_EQ(1, server_connection_->streamInfo().dynamicMetadata().filter_metadata_size()); + auto metadata = server_connection_->streamInfo().dynamicMetadata().filter_metadata(); + EXPECT_EQ(1, metadata.size()); + EXPECT_EQ(1, metadata.count(ProxyProtocol)); + + auto fields = metadata.at(ProxyProtocol).fields(); + EXPECT_EQ(1, fields.size()); + + EXPECT_EQ(1, fields.count("PP2 type authority")); + auto value_s = fields.at("PP2 type authority").string_value(); + EXPECT_EQ(tlv_data, value_s); + + disconnect(); +} + TEST_P(ProxyProtocolTest, V2ExtractTlvOfInterest) { // A well-formed ipv4/tcp with a pair of TLV extensions is accepted constexpr uint8_t buffer[] = {0x0d, 0x0a, 0x0d, 0x0a, 0x00, 0x0d, 0x0a, 0x51, 0x55, 0x49, @@ -1392,6 +1451,81 @@ TEST_P(ProxyProtocolTest, ClosedEmpty) { dispatcher_->run(Event::Dispatcher::RunType::NonBlock); } +// There is no chance to have error for Windows since it emulate the drain +// from a memory buffer. +#ifndef WIN32 +TEST_P(ProxyProtocolTest, DrainError) { + Api::MockOsSysCalls os_sys_calls; + TestThreadsafeSingletonInjector os_calls(&os_sys_calls); + + EXPECT_CALL(os_sys_calls, recv(_, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([&](os_fd_t fd, void* buf, size_t n, int flags) { + if (flags != MSG_PEEK) { + return Api::SysCallSizeResult{-1, 0}; + } else { + const Api::SysCallSizeResult x = os_sys_calls_actual_.recv(fd, buf, n, flags); + return x; + } + })); + EXPECT_CALL(os_sys_calls, readv(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.readv(fd, iov, iovcnt); + })); + EXPECT_CALL(os_sys_calls, connect(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t sockfd, const sockaddr* addr, socklen_t addrlen) { + return os_sys_calls_actual_.connect(sockfd, addr, addrlen); + })); + EXPECT_CALL(os_sys_calls, writev(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](os_fd_t fd, const iovec* iov, int iovcnt) { + return os_sys_calls_actual_.writev(fd, iov, iovcnt); + })); + EXPECT_CALL(os_sys_calls, getsockopt_(_, _, _, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [this](os_fd_t sockfd, int level, int optname, void* optval, socklen_t* optlen) -> int { + return os_sys_calls_actual_.getsockopt(sockfd, level, optname, optval, optlen) + .return_value_; + })); + EXPECT_CALL(os_sys_calls, getsockname(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [this](os_fd_t sockfd, sockaddr* name, socklen_t* namelen) -> Api::SysCallIntResult { + return os_sys_calls_actual_.getsockname(sockfd, name, namelen); + })); + EXPECT_CALL(os_sys_calls, shutdown(_, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [this](os_fd_t sockfd, int how) { return os_sys_calls_actual_.shutdown(sockfd, how); })); + EXPECT_CALL(os_sys_calls, close(_)).Times(AnyNumber()).WillRepeatedly(Invoke([this](os_fd_t fd) { + return os_sys_calls_actual_.close(fd); + })); + EXPECT_CALL(os_sys_calls, accept(_, _, _)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke( + [this](os_fd_t sockfd, sockaddr* addr, socklen_t* addrlen) -> Api::SysCallSocketResult { + return os_sys_calls_actual_.accept(sockfd, addr, addrlen); + })); + EXPECT_CALL(os_sys_calls, supportsGetifaddrs()) + .Times(AnyNumber()) + .WillRepeatedly( + Invoke([this]() -> bool { return os_sys_calls_actual_.supportsGetifaddrs(); })); + EXPECT_CALL(os_sys_calls, getifaddrs(_)) + .Times(AnyNumber()) + .WillRepeatedly(Invoke([this](Api::InterfaceAddressVector& vector) -> Api::SysCallIntResult { + return os_sys_calls_actual_.getifaddrs(vector); + })); + + connect(false); + write("PROXY TCP4 1.2.3.4 253.253.253.253 65535 1234\r\nmore data"); + + expectProxyProtoError(); +} +#endif + class WildcardProxyProtocolTest : public testing::TestWithParam, public Network::ListenerConfig, public Network::FilterChainManager, diff --git a/test/extensions/filters/listener/tls_inspector/BUILD b/test/extensions/filters/listener/tls_inspector/BUILD index 5978f030d237d..426c30a755c21 100644 --- a/test/extensions/filters/listener/tls_inspector/BUILD +++ b/test/extensions/filters/listener/tls_inspector/BUILD @@ -22,6 +22,8 @@ envoy_cc_test( deps = [ ":tls_utility_lib", "//source/common/http:utility_lib", + "//source/common/network:default_socket_interface_lib", + "//source/common/network:listener_filter_buffer_lib", "//source/extensions/filters/listener/tls_inspector:config", "//source/extensions/filters/listener/tls_inspector:tls_inspector_lib", "//test/mocks/api:api_mocks", @@ -62,6 +64,7 @@ envoy_extension_cc_benchmark_binary( ":tls_utility_lib", "//source/common/http:utility_lib", "//source/common/network:listen_socket_lib", + "//source/common/network:listener_filter_buffer_lib", "//source/extensions/filters/listener/tls_inspector:tls_inspector_lib", "//test/mocks/api:api_mocks", "//test/mocks/network:network_mocks", diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc index 66d76337e089b..f1d1e6d89537f 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_benchmark.cc @@ -4,6 +4,7 @@ #include "source/common/http/utility.h" #include "source/common/network/io_socket_handle_impl.h" #include "source/common/network/listen_socket_impl.h" +#include "source/common/network/listener_filter_buffer_impl.h" #include "source/extensions/filters/listener/tls_inspector/tls_inspector.h" #include "test/extensions/filters/listener/tls_inspector/tls_utility.h" @@ -25,14 +26,10 @@ namespace TlsInspector { class FastMockListenerFilterCallbacks : public Network::MockListenerFilterCallbacks { public: - FastMockListenerFilterCallbacks(Network::ConnectionSocket& socket, Event::Dispatcher& dispatcher) - : socket_(socket), dispatcher_(dispatcher) {} + FastMockListenerFilterCallbacks(Network::ConnectionSocket& socket) : socket_(socket) {} Network::ConnectionSocket& socket() override { return socket_; } - Event::Dispatcher& dispatcher() override { return dispatcher_; } - void continueFilterChain(bool success) override { RELEASE_ASSERT(success, ""); } Network::ConnectionSocket& socket_; - Event::Dispatcher& dispatcher_; }; // Don't inherit from the mock implementation at all, because this is instantiated @@ -79,12 +76,18 @@ static void BM_TlsInspector(benchmark::State& state) { Network::IoHandlePtr io_handle = std::make_unique(); Network::ConnectionSocketImpl socket(std::move(io_handle), nullptr, nullptr); NiceMock dispatcher; - FastMockListenerFilterCallbacks cb(socket, dispatcher); + FastMockListenerFilterCallbacks cb(socket); + Network::ListenerFilterBufferImpl buffer( + socket.ioHandle(), dispatcher, [](bool) {}, [](Network::ListenerFilterBuffer&) {}, + cfg->maxClientHelloSize()); + dispatcher.file_event_callback_(Event::FileReadyType::Read); for (auto _ : state) { + UNREFERENCED_PARAMETER(_); Filter filter(cfg); filter.onAccept(cb); - RELEASE_ASSERT(dispatcher.file_event_callback_ == nullptr, ""); + auto filter_state = filter.onData(buffer); + RELEASE_ASSERT(filter_state == Network::FilterStatus::Continue, ""); RELEASE_ASSERT(socket.detectedTransportProtocol() == "tls", ""); RELEASE_ASSERT(socket.requestedServerName() == "example.com", ""); RELEASE_ASSERT(socket.requestedApplicationProtocols().size() == 2 && diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.cc index d1be04b6f8bcc..2660d803c55b7 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.cc @@ -30,7 +30,7 @@ DEFINE_PROTO_FUZZER( auto filter = std::make_unique(std::move(cfg)); - ListenerFilterFuzzer fuzzer; + ListenerFilterWithDataFuzzer fuzzer; fuzzer.fuzz(std::move(filter), input.fuzzed()); } diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.proto b/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.proto index 37c9423dac38d..db44ae2473be6 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.proto +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_fuzz_test.proto @@ -10,6 +10,6 @@ message TlsInspectorTestCase { envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector config = 1 [(validate.rules).message.required = true]; uint32 max_size = 2 [(validate.rules).uint32.lte = 65536]; - test.extensions.filters.listener.FilterFuzzTestCase fuzzed = 3 + test.extensions.filters.listener.FilterFuzzWithDataTestCase fuzzed = 3 [(validate.rules).message.required = true]; } diff --git a/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc b/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc index 06d49dae2e214..6fe158e7ed07f 100644 --- a/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc +++ b/test/extensions/filters/listener/tls_inspector/tls_inspector_test.cc @@ -1,6 +1,7 @@ #include "source/common/common/hex.h" #include "source/common/http/utility.h" #include "source/common/network/io_socket_handle_impl.h" +#include "source/common/network/listener_filter_buffer_impl.h" #include "source/extensions/filters/listener/tls_inspector/tls_inspector.h" #include "test/extensions/filters/listener/tls_inspector/tls_utility.h" @@ -36,30 +37,45 @@ class TlsInspectorTest : public testing::TestWithParam( store_, envoy::extensions::filters::listener::tls_inspector::v3::TlsInspector())), - io_handle_(std::make_unique(42)) {} - ~TlsInspectorTest() override { io_handle_->close(); } + io_handle_( + Network::SocketInterfaceImpl::makePlatformSpecificSocket(42, false, absl::nullopt)) {} void init() { filter_ = std::make_unique(cfg_); EXPECT_CALL(cb_, socket()).WillRepeatedly(ReturnRef(socket_)); - EXPECT_CALL(cb_, dispatcher()).WillRepeatedly(ReturnRef(dispatcher_)); EXPECT_CALL(socket_, ioHandle()).WillRepeatedly(ReturnRef(*io_handle_)); - - // Prepare the first recv attempt during - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([](os_fd_t fd, void* buffer, size_t length, int flag) -> Api::SysCallSizeResult { - ENVOY_LOG_MISC(debug, "In mock syscall recv {} {} {} {}", fd, buffer, length, flag); - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN}; - })); EXPECT_CALL(dispatcher_, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) .WillOnce( DoAll(SaveArg<1>(&file_event_callback_), ReturnNew>())); + buffer_ = std::make_unique( + *io_handle_, dispatcher_, [](bool) {}, [](Network::ListenerFilterBuffer&) {}, + cfg_->maxClientHelloSize()); filter_->onAccept(cb_); } + void mockSysCallForPeek(std::vector& client_hello) { +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [&client_hello](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= client_hello.size()); + memcpy(iov->iov_base, client_hello.data(), client_hello.size()); + return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); +#else + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce(Invoke( + [&client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= client_hello.size()); + memcpy(buffer, client_hello.data(), client_hello.size()); + return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; + })); +#endif + } + void testJA3(const std::string& fingerprint, bool expect_server_name = true, const std::string& hash = {}); @@ -73,6 +89,7 @@ class TlsInspectorTest : public testing::TestWithParam dispatcher_; Event::FileReadyCb file_event_callback_; Network::IoHandlePtr io_handle_; + std::unique_ptr buffer_; }; INSTANTIATE_TEST_SUITE_P(TlsProtocolVersions, TlsInspectorTest, @@ -91,45 +108,20 @@ TEST_P(TlsInspectorTest, MaxClientHelloSize) { "max_client_hello_size of 65537 is greater than maximum of 65536."); } -// Test that the filter detects Closed events and terminates. -TEST_P(TlsInspectorTest, ConnectionClosed) { - init(); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Return(Api::SysCallSizeResult{0, 0})); - EXPECT_CALL(cb_, continueFilterChain(false)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().connection_closed_.value()); -} - -// Test that the filter detects detects read errors. -TEST_P(TlsInspectorTest, ReadError) { - init(); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)).WillOnce(InvokeWithoutArgs([]() { - return Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_NOT_SUP}; - })); - EXPECT_CALL(cb_, continueFilterChain(false)); - file_event_callback_(Event::FileReadyType::Read); - EXPECT_EQ(1, cfg_->stats().read_error_.value()); -} - // Test that a ClientHello with an SNI value causes the correct name notification. TEST_P(TlsInspectorTest, SniRegistered) { init(); const std::string servername("example.com"); std::vector client_hello = Tls::Test::generateClientHello( std::get<0>(GetParam()), std::get<1>(GetParam()), servername, ""); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= client_hello.size()); - memcpy(buffer, client_hello.data(), client_hello.size()); - return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; - })); + mockSysCallForPeek(client_hello); EXPECT_CALL(socket_, setRequestedServerName(Eq(servername))); EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); - EXPECT_CALL(cb_, continueFilterChain(true)); + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::Continue, state); EXPECT_EQ(1, cfg_->stats().tls_found_.value()); EXPECT_EQ(1, cfg_->stats().sni_found_.value()); EXPECT_EQ(1, cfg_->stats().alpn_not_found_.value()); @@ -142,18 +134,14 @@ TEST_P(TlsInspectorTest, AlpnRegistered) { Http::Utility::AlpnNames::get().Http11}; std::vector client_hello = Tls::Test::generateClientHello( std::get<0>(GetParam()), std::get<1>(GetParam()), "", "\x02h2\x08http/1.1"); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= client_hello.size()); - memcpy(buffer, client_hello.data(), client_hello.size()); - return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; - })); + mockSysCallForPeek(client_hello); EXPECT_CALL(socket_, setRequestedServerName(_)).Times(0); EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); - EXPECT_CALL(cb_, continueFilterChain(true)); + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::Continue, state); EXPECT_EQ(1, cfg_->stats().tls_found_.value()); EXPECT_EQ(1, cfg_->stats().sni_not_found_.value()); EXPECT_EQ(1, cfg_->stats().alpn_found_.value()); @@ -166,6 +154,23 @@ TEST_P(TlsInspectorTest, MultipleReads) { const std::string servername("example.com"); std::vector client_hello = Tls::Test::generateClientHello( std::get<0>(GetParam()), std::get<1>(GetParam()), servername, "\x02h2"); +#ifdef WIN32 + { + InSequence s; + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + for (size_t i = 0; i < client_hello.size(); i++) { + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke([&client_hello, i](os_fd_t fd, const iovec* iov, + int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= client_hello.size()); + memcpy(iov->iov_base, client_hello.data() + i, 1); + return Api::SysCallSizeResult{ssize_t(1), 0}; + })) + .WillOnce(Return(Api::SysCallSizeResult{ssize_t(-1), SOCKET_ERROR_AGAIN})); + } + } +#else { InSequence s; EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) @@ -182,16 +187,18 @@ TEST_P(TlsInspectorTest, MultipleReads) { })); } } - +#endif bool got_continue = false; EXPECT_CALL(socket_, setRequestedServerName(Eq(servername))); EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); - EXPECT_CALL(cb_, continueFilterChain(true)).WillOnce(InvokeWithoutArgs([&got_continue]() { - got_continue = true; - })); while (!got_continue) { + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + if (state == Network::FilterStatus::Continue) { + got_continue = true; + } } EXPECT_EQ(1, cfg_->stats().tls_found_.value()); EXPECT_EQ(1, cfg_->stats().sni_found_.value()); @@ -203,18 +210,14 @@ TEST_P(TlsInspectorTest, NoExtensions) { init(); std::vector client_hello = Tls::Test::generateClientHello(std::get<0>(GetParam()), std::get<1>(GetParam()), "", ""); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= client_hello.size()); - memcpy(buffer, client_hello.data(), client_hello.size()); - return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; - })); + mockSysCallForPeek(client_hello); EXPECT_CALL(socket_, setRequestedServerName(_)).Times(0); EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); - EXPECT_CALL(cb_, continueFilterChain(true)); + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::Continue, state); EXPECT_EQ(1, cfg_->stats().tls_found_.value()); EXPECT_EQ(1, cfg_->stats().sni_not_found_.value()); EXPECT_EQ(1, cfg_->stats().alpn_not_found_.value()); @@ -229,7 +232,29 @@ TEST_P(TlsInspectorTest, ClientHelloTooBig) { std::vector client_hello = Tls::Test::generateClientHello( std::get<0>(GetParam()), std::get<1>(GetParam()), "example.com", ""); ASSERT(client_hello.size() > max_size); - init(); + + filter_ = std::make_unique(cfg_); + + EXPECT_CALL(cb_, socket()).WillRepeatedly(ReturnRef(socket_)); + EXPECT_CALL(socket_, ioHandle()).WillRepeatedly(ReturnRef(*io_handle_)); + EXPECT_CALL(dispatcher_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce( + DoAll(SaveArg<1>(&file_event_callback_), ReturnNew>())); + buffer_ = std::make_unique( + *io_handle_, dispatcher_, [](bool) {}, [](Network::ListenerFilterBuffer&) {}, + cfg_->maxClientHelloSize()); + + filter_->onAccept(cb_); +#ifdef WIN32 + EXPECT_CALL(os_sys_calls_, readv(_, _, _)) + .WillOnce(Invoke( + [=, &client_hello](os_fd_t fd, const iovec* iov, int iovcnt) -> Api::SysCallSizeResult { + ASSERT(iov->iov_len >= max_size); + memcpy(iov->iov_base, client_hello.data(), max_size); + return Api::SysCallSizeResult{ssize_t(max_size), 0}; + })); +#else EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) .WillOnce(Invoke( [=, &client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { @@ -237,8 +262,11 @@ TEST_P(TlsInspectorTest, ClientHelloTooBig) { memcpy(buffer, client_hello.data(), length); return Api::SysCallSizeResult{ssize_t(length), 0}; })); - EXPECT_CALL(cb_, continueFilterChain(false)); +#endif + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::StopIteration, state); EXPECT_EQ(1, cfg_->stats().client_hello_too_large_.value()); } @@ -250,19 +278,15 @@ TEST_P(TlsInspectorTest, ConnectionFingerprint) { std::vector client_hello = Tls::Test::generateClientHello(std::get<0>(GetParam()), std::get<1>(GetParam()), "", ""); init(); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= client_hello.size()); - memcpy(buffer, client_hello.data(), client_hello.size()); - return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; - })); + mockSysCallForPeek(client_hello); EXPECT_CALL(socket_, setJA3Hash(_)); EXPECT_CALL(socket_, setRequestedServerName(_)).Times(0); EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); - EXPECT_CALL(cb_, continueFilterChain(true)); + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::Continue, state); } void TlsInspectorTest::testJA3(const std::string& fingerprint, bool expect_server_name, @@ -272,13 +296,7 @@ void TlsInspectorTest::testJA3(const std::string& fingerprint, bool expect_serve cfg_ = std::make_shared(store_, proto_config); std::vector client_hello = Tls::Test::generateClientHelloFromJA3Fingerprint(fingerprint); init(); - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke( - [&client_hello](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= client_hello.size()); - memcpy(buffer, client_hello.data(), client_hello.size()); - return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; - })); + mockSysCallForPeek(client_hello); if (hash.empty()) { uint8_t buf[MD5_DIGEST_LENGTH]; MD5(reinterpret_cast(fingerprint.data()), fingerprint.size(), buf); @@ -290,9 +308,12 @@ void TlsInspectorTest::testJA3(const std::string& fingerprint, bool expect_serve EXPECT_CALL(socket_, setRequestedServerName(absl::string_view("www.envoyproxy.io"))); } EXPECT_CALL(socket_, setRequestedApplicationProtocols(_)).Times(0); - EXPECT_CALL(cb_, continueFilterChain(true)); + // EXPECT_CALL(cb_, continueFilterChain(true)); EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::Continue, state); } // Test that the filter sets the correct `JA3` hash. @@ -360,51 +381,14 @@ TEST_P(TlsInspectorTest, NotSsl) { // Use 100 bytes of zeroes. This is not valid as a ClientHello. data.resize(100); - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce( - Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { - ASSERT(length >= data.size()); - memcpy(buffer, data.data(), data.size()); - return Api::SysCallSizeResult{ssize_t(data.size()), 0}; - })); - EXPECT_CALL(cb_, continueFilterChain(true)); + mockSysCallForPeek(data); + // trigger the event to copy the client hello message into buffer file_event_callback_(Event::FileReadyType::Read); + auto state = filter_->onData(*buffer_); + EXPECT_EQ(Network::FilterStatus::Continue, state); EXPECT_EQ(1, cfg_->stats().tls_not_found_.value()); } -TEST_P(TlsInspectorTest, InlineReadSucceed) { - filter_ = std::make_unique(cfg_); - - EXPECT_CALL(cb_, socket()).WillRepeatedly(ReturnRef(socket_)); - EXPECT_CALL(cb_, dispatcher()).WillRepeatedly(ReturnRef(dispatcher_)); - EXPECT_CALL(socket_, ioHandle()).WillRepeatedly(ReturnRef(*io_handle_)); - const auto alpn_protos = std::vector{Http::Utility::AlpnNames::get().Http2}; - const std::string servername("example.com"); - std::vector client_hello = Tls::Test::generateClientHello( - std::get<0>(GetParam()), std::get<1>(GetParam()), servername, "\x02h2"); - - EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) - .WillOnce(Invoke([&client_hello](os_fd_t fd, void* buffer, size_t length, - int flag) -> Api::SysCallSizeResult { - ENVOY_LOG_MISC(trace, "In mock syscall recv {} {} {} {}", fd, buffer, length, flag); - ASSERT(length >= client_hello.size()); - memcpy(buffer, client_hello.data(), client_hello.size()); - return Api::SysCallSizeResult{ssize_t(client_hello.size()), 0}; - })); - - // No event is created if the inline recv parse the hello. - EXPECT_CALL(dispatcher_, - createFileEvent_(_, _, Event::PlatformDefaultTriggerType, - Event::FileReadyType::Read | Event::FileReadyType::Closed)) - .Times(0); - - EXPECT_CALL(socket_, setRequestedServerName(Eq(servername))); - EXPECT_CALL(socket_, setRequestedApplicationProtocols(alpn_protos)); - EXPECT_CALL(socket_, setDetectedTransportProtocol(absl::string_view("tls"))); - EXPECT_EQ(Network::FilterStatus::Continue, filter_->onAccept(cb_)); -} - } // namespace } // namespace TlsInspector } // namespace ListenerFilters diff --git a/test/integration/filters/address_restore_listener_filter.cc b/test/integration/filters/address_restore_listener_filter.cc index ed1605924d8dc..4b7de0797aa1e 100644 --- a/test/integration/filters/address_restore_listener_filter.cc +++ b/test/integration/filters/address_restore_listener_filter.cc @@ -31,6 +31,12 @@ class FakeOriginalDstListenerFilter : public Network::ListenerFilter { socket.connectionInfoProvider().localAddressRestored()); return Network::FilterStatus::Continue; } + + size_t maxReadBytes() const override { return 0; } + + Network::FilterStatus onData(Network::ListenerFilterBuffer&) override { + return Network::FilterStatus::Continue; + } }; class FakeOriginalDstListenerFilterConfigFactory diff --git a/test/mocks/network/mocks.cc b/test/mocks/network/mocks.cc index 04958420157bc..f27c1738ed368 100644 --- a/test/mocks/network/mocks.cc +++ b/test/mocks/network/mocks.cc @@ -109,7 +109,6 @@ MockUdpListenerCallbacks::~MockUdpListenerCallbacks() = default; MockDrainDecision::MockDrainDecision() = default; MockDrainDecision::~MockDrainDecision() = default; -MockListenerFilter::MockListenerFilter() = default; MockListenerFilter::~MockListenerFilter() { destroy_(); } MockListenerFilterCallbacks::MockListenerFilterCallbacks() { diff --git a/test/mocks/network/mocks.h b/test/mocks/network/mocks.h index 2acb8562d3e1b..2360c129f7f4a 100644 --- a/test/mocks/network/mocks.h +++ b/test/mocks/network/mocks.h @@ -196,11 +196,16 @@ class MockDrainDecision : public DrainDecision { class MockListenerFilter : public ListenerFilter { public: - MockListenerFilter(); + MockListenerFilter(size_t max_read_bytes = 0) : listener_filter_max_read_bytes_(max_read_bytes) {} ~MockListenerFilter() override; + size_t maxReadBytes() const override { return listener_filter_max_read_bytes_; } + MOCK_METHOD(void, destroy_, ()); MOCK_METHOD(Network::FilterStatus, onAccept, (ListenerFilterCallbacks&)); + MOCK_METHOD(Network::FilterStatus, onData, (Network::ListenerFilterBuffer&)); + + size_t listener_filter_max_read_bytes_{0}; }; class MockListenerFilterManager : public ListenerFilterManager { diff --git a/test/per_file_coverage.sh b/test/per_file_coverage.sh index 5bc6559ab9da1..c8b07396a478e 100755 --- a/test/per_file_coverage.sh +++ b/test/per_file_coverage.sh @@ -49,7 +49,8 @@ declare -a KNOWN_LOW_COVERAGE=( "source/extensions/filters/http/wasm:95.8" "source/extensions/filters/listener:95.9" "source/extensions/filters/listener/http_inspector:95.8" -"source/extensions/filters/listener/original_dst:93.3" +"source/extensions/filters/listener/original_dst:82.4" +"source/extensions/filters/listener/original_src:92.1" "source/extensions/filters/listener/tls_inspector:92.3" "source/extensions/filters/network/common:96.0" "source/extensions/filters/network/common/redis:96.2" diff --git a/test/server/BUILD b/test/server/BUILD index e62d1f586941a..5b47d0b5cc391 100644 --- a/test/server/BUILD +++ b/test/server/BUILD @@ -88,6 +88,8 @@ envoy_cc_test( "//test/test_common:network_utility_lib", "//test/test_common:test_runtime_lib", "//test/test_common:threadsafe_singleton_injector_lib", + "@envoy_api//envoy/config/core/v3:pkg_cc_proto", + "@envoy_api//envoy/config/listener/v3:pkg_cc_proto", ], ) diff --git a/test/server/active_internal_listener_test.cc b/test/server/active_internal_listener_test.cc index a2e60f3bc8c89..b2e50e9374317 100644 --- a/test/server/active_internal_listener_test.cc +++ b/test/server/active_internal_listener_test.cc @@ -1,5 +1,6 @@ #include +#include "envoy/api/io_error.h" #include "envoy/network/filter.h" #include "envoy/network/listener.h" #include "envoy/stats/scope.h" @@ -17,6 +18,7 @@ #include "gtest/gtest.h" using testing::_; +using testing::ByMove; using testing::Invoke; using testing::NiceMock; using testing::Return; @@ -111,10 +113,14 @@ TEST_F(ActiveInternalListenerTest, AcceptSocketAndCreateListenerFilter) { TEST_F(ActiveInternalListenerTest, DestroyListenerClosesActiveSocket) { addListener(); expectFilterChainFactory(); - Network::MockListenerFilter* test_listener_filter = new Network::MockListenerFilter(); + Network::MockListenerFilter* test_listener_filter = new Network::MockListenerFilter(10); Network::MockConnectionSocket* accepted_socket = new NiceMock(); NiceMock io_handle; - EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)); + EXPECT_CALL(*accepted_socket, ioHandle()).WillRepeatedly(ReturnRef(io_handle)); + EXPECT_CALL(io_handle, recv) + .WillOnce(Return(ByMove(Api::IoCallUint64Result( + 0, Api::IoErrorPtr(Network::IoSocketError::getIoSocketEagainInstance(), + Network::IoSocketError::deleteIoError))))); EXPECT_CALL(io_handle, isOpen()).WillOnce(Return(true)); EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) diff --git a/test/server/active_tcp_listener_test.cc b/test/server/active_tcp_listener_test.cc index 238ebe47bdca4..83d0e7b7c6b6b 100644 --- a/test/server/active_tcp_listener_test.cc +++ b/test/server/active_tcp_listener_test.cc @@ -20,10 +20,12 @@ #include "gtest/gtest.h" using testing::_; +using testing::ByMove; using testing::Invoke; using testing::NiceMock; using testing::Return; using testing::ReturnRef; +using testing::SaveArg; namespace Envoy { namespace Server { @@ -48,72 +50,363 @@ class ActiveTcpListenerTest : public testing::Test, protected Logger::Loggable>(); } + void initialize() { + EXPECT_CALL(listener_config_, connectionBalancer()).WillRepeatedly(ReturnRef(balancer_)); + EXPECT_CALL(listener_config_, listenerScope).Times(testing::AnyNumber()); + EXPECT_CALL(listener_config_, listenerFiltersTimeout()); + EXPECT_CALL(listener_config_, continueOnListenerFiltersTimeout()); + EXPECT_CALL(listener_config_, filterChainManager()).WillRepeatedly(ReturnRef(manager_)); + EXPECT_CALL(listener_config_, openConnections()).WillRepeatedly(ReturnRef(resource_limit_)); + EXPECT_CALL(listener_config_, filterChainFactory()) + .WillRepeatedly(ReturnRef(filter_chain_factory_)); + } + + void initializeWithInspectFilter() { + initialize(); + filter_ = new NiceMock(inspect_size_); + EXPECT_CALL(*filter_, destroy_()); + EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{filter_}); + return true; + })); + generic_listener_ = std::make_unique>(); + EXPECT_CALL(*generic_listener_, onDestroy()); + generic_active_listener_ = std::make_unique( + conn_handler_, std::move(generic_listener_), listener_config_, runtime_); + generic_active_listener_->incNumConnections(); + generic_accepted_socket_ = std::make_unique>(); + EXPECT_CALL(*generic_accepted_socket_, ioHandle()).WillRepeatedly(ReturnRef(io_handle_)); + } + + void initializeWithFilter() { + initialize(); + filter_ = new NiceMock(); + EXPECT_CALL(*filter_, destroy_()); + EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{filter_}); + return true; + })); + generic_listener_ = std::make_unique>(); + EXPECT_CALL(*generic_listener_, onDestroy()); + generic_active_listener_ = std::make_unique( + conn_handler_, std::move(generic_listener_), listener_config_, runtime_); + generic_active_listener_->incNumConnections(); + generic_accepted_socket_ = std::make_unique>(); + EXPECT_CALL(*generic_accepted_socket_, ioHandle()).WillRepeatedly(ReturnRef(io_handle_)); + } + std::string listener_stat_prefix_{"listener_stat_prefix"}; std::shared_ptr socket_factory_{ std::make_shared()}; NiceMock dispatcher_{"test"}; BasicResourceLimitImpl resource_limit_; NiceMock conn_handler_; - Network::MockListener* generic_listener_; + std::unique_ptr> generic_listener_; Network::MockListenerConfig listener_config_; NiceMock manager_; NiceMock filter_chain_factory_; std::shared_ptr filter_chain_; std::shared_ptr> listener_filter_matcher_; + NiceMock balancer_; + NiceMock* filter_; + size_t inspect_size_{128}; + std::unique_ptr generic_active_listener_; + NiceMock io_handle_; + std::unique_ptr> generic_accepted_socket_; NiceMock runtime_; }; -TEST_F(ActiveTcpListenerTest, PopulateSNIWhenActiveTcpSocketTimeout) { - NiceMock balancer; - EXPECT_CALL(listener_config_, connectionBalancer()).WillRepeatedly(ReturnRef(balancer)); - EXPECT_CALL(listener_config_, listenerScope).Times(testing::AnyNumber()); - EXPECT_CALL(listener_config_, listenerFiltersTimeout()) - .WillOnce(Return(std::chrono::milliseconds(1000))); - EXPECT_CALL(listener_config_, continueOnListenerFiltersTimeout()); - EXPECT_CALL(listener_config_, openConnections()).WillRepeatedly(ReturnRef(resource_limit_)); +/** + * Execute peek data two times, then filter return successful. + */ +TEST_F(ActiveTcpListenerTest, ListenerFilterWithInspectData) { + initializeWithInspectFilter(); + + // The filter stop the filter iteration and waiting for the data. + EXPECT_CALL(*filter_, onAccept(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(io_handle_, isOpen()).WillRepeatedly(Return(true)); + + Event::FileReadyCb file_event_callback; + // ensure the listener filter buffer will register the file event. + EXPECT_CALL(io_handle_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback)); + + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove(Api::IoCallUint64Result( + inspect_size_ / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + // the filter is looking for more data. + EXPECT_CALL(*filter_, onData(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + generic_active_listener_->onAcceptWorker(std::move(generic_accepted_socket_), false, true); + + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove( + Api::IoCallUint64Result(inspect_size_, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + // the filter get enough data, then return Network::FilterStatus::Continue + EXPECT_CALL(*filter_, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(nullptr)); + EXPECT_CALL(io_handle_, resetFileEvents()); + file_event_callback(Event::FileReadyType::Read); +} + +/** + * The event triggered data peek failed. + */ +TEST_F(ActiveTcpListenerTest, ListenerFilterWithInspectDataFailedWithPeek) { + initializeWithInspectFilter(); + + // The filter stop the filter iteration and waiting for the data. + EXPECT_CALL(*filter_, onAccept(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + + EXPECT_CALL(io_handle_, isOpen()).WillRepeatedly(Return(true)); + Event::FileReadyCb file_event_callback; + // ensure the listener filter buffer will register the file event. + EXPECT_CALL(io_handle_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback)); + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove(Api::IoCallUint64Result( + inspect_size_ / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + EXPECT_CALL(io_handle_, close) + .WillOnce(Return( + ByMove(Api::IoCallUint64Result(0, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + // the filter is looking for more data. + EXPECT_CALL(*filter_, onData(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + // calling the onAcceptWorker() to create the ActiveTcpSocket. + generic_active_listener_->onAcceptWorker(std::move(generic_accepted_socket_), false, true); + + // peek data failed. + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove( + Api::IoCallUint64Result(0, Api::IoErrorPtr(new Network::IoSocketError(SOCKET_ERROR_INTR), + Network::IoSocketError::deleteIoError))))); + + file_event_callback(Event::FileReadyType::Read); + EXPECT_EQ(generic_active_listener_->stats_.downstream_listener_filter_error_.value(), 1); +} + +/** + * Multiple filters with different `MaxReadBytes()` value. + */ +TEST_F(ActiveTcpListenerTest, ListenerFilterWithInspectDataMultipleFilters) { + initialize(); + + auto inspect_size1 = 128; + auto* inspect_data_filter1 = new NiceMock(inspect_size1); + EXPECT_CALL(*inspect_data_filter1, destroy_()); + + auto inspect_size2 = 512; + auto* inspect_data_filter2 = new NiceMock(inspect_size2); + EXPECT_CALL(*inspect_data_filter2, destroy_()); + + auto inspect_size3 = 256; + auto* inspect_data_filter3 = new NiceMock(inspect_size3); + EXPECT_CALL(*inspect_data_filter3, destroy_()); + + auto* no_inspect_data_filter = new NiceMock(); + EXPECT_CALL(*no_inspect_data_filter, destroy_()); + + EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) + .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{inspect_data_filter1}); + // Expect the `onData()` callback won't be called for this filter. + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{no_inspect_data_filter}); + // Expect the ListenerFilterBuffer's capacity will be increased for this filter. + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{inspect_data_filter2}); + // Expect the ListenerFilterBuffer's capacity won't be decreased. + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{inspect_data_filter3}); + return true; + })); auto listener = std::make_unique>(); EXPECT_CALL(*listener, onDestroy()); + auto active_listener = std::make_unique(conn_handler_, std::move(listener), + listener_config_, runtime_); + auto accepted_socket = std::make_unique>(); - auto* test_filter = new NiceMock(); - EXPECT_CALL(*test_filter, destroy_()); - EXPECT_CALL(listener_config_, filterChainFactory()) - .WillRepeatedly(ReturnRef(filter_chain_factory_)); + EXPECT_CALL(*accepted_socket, ioHandle()).WillRepeatedly(ReturnRef(io_handle_)); + EXPECT_CALL(io_handle_, isOpen()).WillRepeatedly(Return(true)); + Event::FileReadyCb file_event_callback; + + EXPECT_CALL(io_handle_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback)); + EXPECT_CALL(io_handle_, recv) + .WillOnce([&](void*, size_t size, int) { + EXPECT_EQ(128, size); + return Api::IoCallUint64Result(128, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }) + .WillOnce([&](void*, size_t size, int) { + EXPECT_EQ(512, size); + return Api::IoCallUint64Result(512, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }) + .WillOnce([&](void*, size_t size, int) { + EXPECT_EQ(512, size); + return Api::IoCallUint64Result(512, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + + EXPECT_CALL(*inspect_data_filter1, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(*inspect_data_filter1, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + + EXPECT_CALL(*no_inspect_data_filter, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::Continue)); + EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(nullptr)); + + EXPECT_CALL(*inspect_data_filter2, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(*inspect_data_filter2, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + + EXPECT_CALL(*inspect_data_filter3, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(*inspect_data_filter3, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + + active_listener->incNumConnections(); + // Calling the onAcceptWorker() to create the ActiveTcpSocket. + active_listener->onAcceptWorker(std::move(accepted_socket), false, true); +} + +/** + * Similar with above test, but with different filters order. + */ +TEST_F(ActiveTcpListenerTest, ListenerFilterWithInspectDataMultipleFilters2) { + initialize(); + + auto inspect_size1 = 128; + auto* inspect_data_filter1 = new NiceMock(inspect_size1); + EXPECT_CALL(*inspect_data_filter1, destroy_()); + + auto inspect_size2 = 512; + auto* inspect_data_filter2 = new NiceMock(inspect_size2); + EXPECT_CALL(*inspect_data_filter2, destroy_()); + + auto inspect_size3 = 256; + auto* inspect_data_filter3 = new NiceMock(inspect_size3); + EXPECT_CALL(*inspect_data_filter3, destroy_()); + + auto* no_inspect_data_filter = new NiceMock(); + EXPECT_CALL(*no_inspect_data_filter, destroy_()); - // add a filter to stop the filter iteration. EXPECT_CALL(filter_chain_factory_, createListenerFilterChain(_)) .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { - manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{test_filter}); + // There will be no ListenerFilterBuffer created for first filter. + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{no_inspect_data_filter}); + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{inspect_data_filter1}); + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{inspect_data_filter2}); + manager.addAcceptFilter(nullptr, Network::ListenerFilterPtr{inspect_data_filter3}); return true; })); - EXPECT_CALL(*test_filter, onAccept(_)) - .WillOnce(Invoke([](Network::ListenerFilterCallbacks&) -> Network::FilterStatus { - return Network::FilterStatus::StopIteration; - })); + auto listener = std::make_unique>(); + EXPECT_CALL(*listener, onDestroy()); auto active_listener = std::make_unique(conn_handler_, std::move(listener), listener_config_, runtime_); - - absl::string_view server_name = "envoy.io"; auto accepted_socket = std::make_unique>(); - accepted_socket->connection_info_provider_->setRequestedServerName(server_name); - // fake the socket is open. - NiceMock io_handle; - EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)); - EXPECT_CALL(io_handle, isOpen()).WillOnce(Return(true)); + EXPECT_CALL(*accepted_socket, ioHandle()).WillRepeatedly(ReturnRef(io_handle_)); + EXPECT_CALL(io_handle_, isOpen()).WillRepeatedly(Return(true)); + Event::FileReadyCb file_event_callback; + + EXPECT_CALL(io_handle_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback)); + EXPECT_CALL(io_handle_, recv) + .WillOnce([&](void*, size_t size, int) { + EXPECT_EQ(128, size); + return Api::IoCallUint64Result(128, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }) + .WillOnce([&](void*, size_t size, int) { + EXPECT_EQ(512, size); + return Api::IoCallUint64Result(512, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }) + .WillOnce([&](void*, size_t size, int) { + EXPECT_EQ(512, size); + return Api::IoCallUint64Result(512, Api::IoErrorPtr(nullptr, [](Api::IoError*) {})); + }); + + EXPECT_CALL(*inspect_data_filter1, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(*inspect_data_filter1, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + + EXPECT_CALL(*no_inspect_data_filter, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::Continue)); + EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(nullptr)); + + EXPECT_CALL(*inspect_data_filter2, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(*inspect_data_filter2, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + + EXPECT_CALL(*inspect_data_filter3, onAccept(_)) + .WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(*inspect_data_filter3, onData(_)).WillOnce(Return(Network::FilterStatus::Continue)); + + active_listener->incNumConnections(); + // Calling the onAcceptWorker() to create the ActiveTcpSocket. + active_listener->onAcceptWorker(std::move(accepted_socket), false, true); +} - EXPECT_CALL(balancer, pickTargetHandler(_)) - .WillOnce(testing::DoAll( - testing::WithArg<0>(Invoke([](auto& target) { target.incNumConnections(); })), - ReturnRef(*active_listener))); +/** + * Trigger the file closed event. + */ +TEST_F(ActiveTcpListenerTest, ListenerFilterWithClose) { + initializeWithInspectFilter(); + + // The filter stop the filter iteration and waiting for the data. + EXPECT_CALL(*filter_, onAccept(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + + EXPECT_CALL(io_handle_, isOpen()).WillRepeatedly(Return(true)); + Event::FileReadyCb file_event_callback; + // ensure the listener filter buffer will register the file event. + EXPECT_CALL(io_handle_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback)); + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove(Api::IoCallUint64Result( + inspect_size_ / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + // the filter is looking for more data + EXPECT_CALL(*filter_, onData(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + generic_active_listener_->onAcceptWorker(std::move(generic_accepted_socket_), false, true); + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return( + ByMove(Api::IoCallUint64Result(0, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + EXPECT_CALL(io_handle_, close) + .WillOnce(Return( + ByMove(Api::IoCallUint64Result(0, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + // emit the read event + file_event_callback(Event::FileReadyType::Read); + EXPECT_EQ(generic_active_listener_->stats_.downstream_listener_filter_remote_close_.value(), 1); +} + +TEST_F(ActiveTcpListenerTest, PopulateSNIWhenActiveTcpSocketTimeout2) { + initializeWithInspectFilter(); + + // The filter stop the filter iteration and waiting for the data. + EXPECT_CALL(*filter_, onAccept(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + EXPECT_CALL(io_handle_, isOpen()).WillRepeatedly(Return(true)); + + Event::FileReadyCb file_event_callback; + // ensure the listener filter buffer will register the file event. + EXPECT_CALL(io_handle_, + createFileEvent_(_, _, Event::PlatformDefaultTriggerType, Event::FileReadyType::Read)) + .WillOnce(SaveArg<1>(&file_event_callback)); + + EXPECT_CALL(io_handle_, recv) + .WillOnce(Return(ByMove(Api::IoCallUint64Result( + inspect_size_ / 2, Api::IoErrorPtr(nullptr, [](Api::IoError*) {}))))); + // the filter is looking for more data. + EXPECT_CALL(*filter_, onData(_)).WillOnce(Return(Network::FilterStatus::StopIteration)); + + absl::string_view server_name = "envoy.io"; + generic_accepted_socket_->connection_info_provider_->setRequestedServerName(server_name); + + generic_active_listener_->onAcceptWorker(std::move(generic_accepted_socket_), false, true); - // calling the onAcceptWorker() to create the ActiveTcpSocket. - active_listener->onAcceptWorker(std::move(accepted_socket), false, false); // get the ActiveTcpSocket pointer before unlink() removed from the link-list. - ActiveTcpSocket* tcp_socket = active_listener->sockets().front().get(); + ActiveTcpSocket* tcp_socket = generic_active_listener_->sockets().front().get(); // trigger the onTimeout event manually, since the timer is fake. - active_listener->sockets().front()->onTimeout(); + generic_active_listener_->sockets().front()->onTimeout(); EXPECT_EQ(server_name, tcp_socket->stream_info_->downstreamAddressProvider().requestedServerName()); diff --git a/test/server/connection_handler_test.cc b/test/server/connection_handler_test.cc index a1822db9d00d5..fe84ce5060816 100644 --- a/test/server/connection_handler_test.cc +++ b/test/server/connection_handler_test.cc @@ -4,6 +4,10 @@ #include #include +#include "envoy/api/io_error.h" +#include "envoy/config/core/v3/base.pb.h" +#include "envoy/config/listener/v3/udp_listener_config.pb.h" +#include "envoy/network/exception.h" #include "envoy/network/filter.h" #include "envoy/stats/scope.h" @@ -31,7 +35,6 @@ #include "gtest/gtest.h" using testing::_; -using testing::HasSubstr; using testing::InSequence; using testing::Invoke; using testing::MockFunction; @@ -1700,7 +1703,7 @@ TEST_F(ConnectionHandlerTest, ListenerFilterTimeout) { .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener, runtime_); - Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(512); EXPECT_CALL(factory_, createListenerFilterChain(_)) .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { manager.addAcceptFilter(listener_filter_matcher_, Network::ListenerFilterPtr{test_filter}); @@ -1711,8 +1714,17 @@ TEST_F(ConnectionHandlerTest, ListenerFilterTimeout) { return Network::FilterStatus::StopIteration; })); Network::MockConnectionSocket* accepted_socket = new NiceMock(); - Network::IoSocketHandleImpl io_handle{42}; - EXPECT_CALL(*accepted_socket, ioHandle()).WillRepeatedly(ReturnRef(io_handle)); + Network::MockIoHandle io_handle; + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(io_handle, isOpen()).WillOnce(Return(true)); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(io_handle, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, + Event::FileReadyType::Read)); + EXPECT_CALL(io_handle, recv).WillOnce([&](void*, size_t, int) { + return Api::IoCallUint64Result( + 0, Api::IoErrorPtr(Network::IoSocketError::getIoSocketEagainInstance(), + Network::IoSocketError::deleteIoError)); + }); Event::MockTimer* timeout = new Event::MockTimer(&dispatcher_); EXPECT_CALL(*timeout, enableTimer(std::chrono::milliseconds(15000), _)); listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}); @@ -1738,6 +1750,7 @@ TEST_F(ConnectionHandlerTest, ListenerFilterTimeout) { TEST_F(ConnectionHandlerTest, ContinueOnListenerFilterTimeout) { InSequence s; + NiceMock os_sys_calls; Network::TcpListenerCallbacks* listener_callbacks; auto listener = new NiceMock(); TestListener* test_listener = @@ -1747,12 +1760,14 @@ TEST_F(ConnectionHandlerTest, ContinueOnListenerFilterTimeout) { .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener, runtime_); - Network::MockListenerFilter* test_filter = new NiceMock(); + Network::MockListenerFilter* test_filter = new NiceMock(128); EXPECT_CALL(factory_, createListenerFilterChain(_)) .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { manager.addAcceptFilter(listener_filter_matcher_, Network::ListenerFilterPtr{test_filter}); return true; })); + + std::string data = "test"; EXPECT_CALL(*test_filter, onAccept(_)) .WillOnce(Invoke([&](Network::ListenerFilterCallbacks&) -> Network::FilterStatus { return Network::FilterStatus::StopIteration; @@ -1760,13 +1775,26 @@ TEST_F(ConnectionHandlerTest, ContinueOnListenerFilterTimeout) { Network::MockConnectionSocket* accepted_socket = new NiceMock(); Network::IoSocketHandleImpl io_handle{42}; EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce( + Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= data.size()); + memcpy(buffer, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })); + EXPECT_CALL(*test_filter, onData(_)) + .WillOnce(Invoke([&](Network::ListenerFilterBuffer&) -> Network::FilterStatus { + return Network::FilterStatus::StopIteration; + })); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + Event::MockTimer* timeout = new Event::MockTimer(&dispatcher_); EXPECT_CALL(*timeout, enableTimer(std::chrono::milliseconds(15000), _)); listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}); Stats::Gauge& downstream_pre_cx_active = stats_store_.gauge("downstream_pre_cx_active", Stats::Gauge::ImportMode::Accumulate); EXPECT_EQ(1UL, downstream_pre_cx_active.value()); - EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); EXPECT_CALL(*test_filter, destroy_()); // Barrier: test_filter must be destructed before findFilterChain EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(nullptr)); @@ -1793,6 +1821,7 @@ TEST_F(ConnectionHandlerTest, ContinueOnListenerFilterTimeout) { TEST_F(ConnectionHandlerTest, ListenerFilterTimeoutResetOnSuccess) { InSequence s; + NiceMock os_sys_calls; Network::TcpListenerCallbacks* listener_callbacks; auto listener = new NiceMock(); TestListener* test_listener = @@ -1801,26 +1830,40 @@ TEST_F(ConnectionHandlerTest, ListenerFilterTimeoutResetOnSuccess) { .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener, runtime_); - Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(123); EXPECT_CALL(factory_, createListenerFilterChain(_)) .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { manager.addAcceptFilter(listener_filter_matcher_, Network::ListenerFilterPtr{test_filter}); return true; })); Network::ListenerFilterCallbacks* listener_filter_cb{}; + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + std::string data = "test"; EXPECT_CALL(*test_filter, onAccept(_)) .WillOnce(Invoke([&](Network::ListenerFilterCallbacks& cb) -> Network::FilterStatus { listener_filter_cb = &cb; return Network::FilterStatus::StopIteration; })); - Network::MockConnectionSocket* accepted_socket = new NiceMock(); Network::IoSocketHandleImpl io_handle{42}; EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(os_sys_calls_, recv(42, _, _, MSG_PEEK)) + .WillOnce( + Invoke([&data](os_fd_t, void* buffer, size_t length, int) -> Api::SysCallSizeResult { + ASSERT(length >= data.size()); + memcpy(buffer, data.data(), data.size()); + return Api::SysCallSizeResult{ssize_t(data.size()), 0}; + })); + EXPECT_CALL(*test_filter, onData(_)) + .WillOnce(Invoke([&](Network::ListenerFilterBuffer&) -> Network::FilterStatus { + return Network::FilterStatus::StopIteration; + })); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); Event::MockTimer* timeout = new Event::MockTimer(&dispatcher_); EXPECT_CALL(*timeout, enableTimer(std::chrono::milliseconds(15000), _)); listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}); - EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(*test_filter, destroy_()); EXPECT_CALL(manager_, findFilterChain(_)).WillOnce(Return(nullptr)); EXPECT_CALL(*access_log_, log(_, _, _, _)); @@ -1849,7 +1892,7 @@ TEST_F(ConnectionHandlerTest, ListenerFilterDisabledTimeout) { .WillRepeatedly(ReturnRef(local_address_)); handler_->addListener(absl::nullopt, *test_listener, runtime_); - Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(); + Network::MockListenerFilter* test_filter = new Network::MockListenerFilter(512); EXPECT_CALL(factory_, createListenerFilterChain(_)) .WillRepeatedly(Invoke([&](Network::ListenerFilterManager& manager) -> bool { manager.addAcceptFilter(listener_filter_matcher_, Network::ListenerFilterPtr{test_filter}); @@ -1859,10 +1902,20 @@ TEST_F(ConnectionHandlerTest, ListenerFilterDisabledTimeout) { .WillOnce(Invoke([&](Network::ListenerFilterCallbacks&) -> Network::FilterStatus { return Network::FilterStatus::StopIteration; })); - EXPECT_CALL(*access_log_, log(_, _, _, _)); + Network::MockIoHandle io_handle; + Network::MockConnectionSocket* accepted_socket = new NiceMock(); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(io_handle, isOpen()).WillOnce(Return(true)); + EXPECT_CALL(*accepted_socket, ioHandle()).WillOnce(ReturnRef(io_handle)).RetiresOnSaturation(); + EXPECT_CALL(io_handle, createFileEvent_(_, _, Event::PlatformDefaultTriggerType, + Event::FileReadyType::Read)); + EXPECT_CALL(io_handle, recv).WillOnce([&](void*, size_t, int) { + return Api::IoCallUint64Result( + 0, Api::IoErrorPtr(Network::IoSocketError::getIoSocketEagainInstance(), + Network::IoSocketError::deleteIoError)); + }); EXPECT_CALL(dispatcher_, createTimer_(_)).Times(0); EXPECT_CALL(*test_filter, destroy_()); - Network::MockConnectionSocket* accepted_socket = new NiceMock(); listener_callbacks->onAccept(Network::ConnectionSocketPtr{accepted_socket}); EXPECT_CALL(*listener, onDestroy()); diff --git a/tools/spelling/spelling_dictionary.txt b/tools/spelling/spelling_dictionary.txt index cc7bdbbad1816..9a2d46e4841c5 100644 --- a/tools/spelling/spelling_dictionary.txt +++ b/tools/spelling/spelling_dictionary.txt @@ -37,6 +37,7 @@ CWND DSR DSS EBADF +ENDIF ENOTCONN EPIPE HEXDIG @@ -767,6 +768,8 @@ intra ints invariance invoker +iov +iovcnt iovec iovecs ips @@ -1082,6 +1085,7 @@ retriable retriggers revalidated revalidation +rfield rmdir rocketmq rewriter