diff --git a/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc b/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc index 7af15c8d305b1..793bd96f32d34 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc +++ b/source/extensions/filters/network/postgres_proxy/postgres_decoder.cc @@ -13,13 +13,16 @@ namespace PostgresProxy { []() -> std::unique_ptr { return createMsgBodyReader<__VA_ARGS__>(); } #define NO_BODY BODY_FORMAT() +constexpr absl::string_view FRONTEND = "Frontend"; +constexpr absl::string_view BACKEND = "Backend"; + void DecoderImpl::initialize() { // Special handler for first message of the transaction. first_ = MessageProcessor{"Startup", BODY_FORMAT(Int32, Repeated), {&DecoderImpl::onStartup}}; // Frontend messages. - FE_messages_.direction_ = "Frontend"; + FE_messages_.direction_ = FRONTEND; // Setup handlers for known messages. absl::flat_hash_map& FE_known_msgs = FE_messages_.messages_; @@ -52,7 +55,7 @@ void DecoderImpl::initialize() { MessageProcessor{"Other", BODY_FORMAT(ByteN), {&DecoderImpl::incMessagesUnknown}}; // Backend messages. - BE_messages_.direction_ = "Backend"; + BE_messages_.direction_ = BACKEND; // Setup handlers for known messages. absl::flat_hash_map& BE_known_msgs = BE_messages_.messages_; @@ -176,88 +179,156 @@ void DecoderImpl::initialize() { }; } -Decoder::Result DecoderImpl::parseHeader(Buffer::Instance& data) { - ENVOY_LOG(trace, "postgres_proxy: parsing message, len {}", data.length()); +/* Main handler for incoming messages. Messages are dispatched based on the + current decoder's state. +*/ +Decoder::Result DecoderImpl::onData(Buffer::Instance& data, bool frontend) { + switch (state_) { + case State::InitState: + return onDataInit(data, frontend); + case State::OutOfSyncState: + case State::EncryptedState: + return onDataIgnore(data, frontend); + case State::InSyncState: + return onDataInSync(data, frontend); + default: + NOT_IMPLEMENTED_GCOVR_EXCL_LINE; + } +} - // The minimum size of the message sufficient for parsing is 5 bytes. - if (data.length() < 5) { +/* Handler for messages when decoder is in Init State. There are very few message types which + are allowed in this state. + If the initial message has the correct syntax and indicates that session should be in + clear-text, the decoder will move to InSyncState. If the initial message has the correct syntax + and indicates that session should be encrypted, the decoder stays in InitState, because the + initial message will be received again after transport socket negotiates SSL. If the message + syntax is incorrect, the decoder will move to OutOfSyncState, in which messages are not parsed. +*/ +Decoder::Result DecoderImpl::onDataInit(Buffer::Instance& data, bool) { + ASSERT(state_ == State::InitState); + + // In Init state the minimum size of the message sufficient for parsing is 4 bytes. + if (data.length() < 4) { // not enough data in the buffer. - return Decoder::NeedMoreData; + return Decoder::Result::NeedMoreData; + } + + // Validate the message before processing. + const MsgBodyReader& f = std::get<1>(first_); + const auto msgParser = f(); + // Run the validation. + message_len_ = data.peekBEInt(0); + if (message_len_ > MAX_STARTUP_PACKET_LENGTH) { + // Message does not conform to the expected format. Move to out-of-sync state. + data.drain(data.length()); + state_ = State::OutOfSyncState; + return Decoder::Result::ReadyForNext; } - if (!startup_) { - data.copyOut(0, 1, &command_); - ENVOY_LOG(trace, "postgres_proxy: command is {}", command_); + Message::ValidationResult validationResult = msgParser->validate(data, 4, message_len_ - 4); + + if (validationResult == Message::ValidationNeedMoreData) { + return Decoder::Result::NeedMoreData; } - // The 1 byte message type and message length should be in the buffer - // Check if the entire message has been read. - std::string message; - message_len_ = data.peekBEInt(startup_ ? 0 : 1); - if (data.length() < (message_len_ + (startup_ ? 0 : 1))) { - ENVOY_LOG(trace, "postgres_proxy: cannot parse message. Need {} bytes in buffer", - message_len_ + (startup_ ? 0 : 1)); - // Not enough data in the buffer. - return Decoder::NeedMoreData; + if (validationResult == Message::ValidationFailed) { + // Message does not conform to the expected format. Move to out-of-sync state. + data.drain(data.length()); + state_ = State::OutOfSyncState; + return Decoder::Result::ReadyForNext; } - if (startup_) { - uint32_t code = data.peekBEInt(4); - // Startup message with 1234 in the most significant 16 bits - // indicate request to encrypt. - if (code >= 0x04d20000) { - encrypted_ = true; - // Handler for SSLRequest (Int32(80877103) = 0x04d2162f) - // See details in https://www.postgresql.org/docs/current/protocol-message-formats.html. - if (code == 0x04d2162f) { - // Notify the filter that `SSLRequest` message was decoded. - // If the filter returns true, it means to pass the message upstream - // to the server. If it returns false it means, that filter will try - // to terminate SSL session and SSLRequest should not be passed to the - // server. - encrypted_ = callbacks_->onSSLRequest(); - } - - // Count it as recognized frontend message. - callbacks_->incMessagesFrontend(); - if (encrypted_) { - ENVOY_LOG(trace, "postgres_proxy: detected encrypted traffic."); - incSessionsEncrypted(); - startup_ = false; - } - data.drain(data.length()); - return encrypted_ ? Decoder::ReadyForNext : Decoder::Stopped; + Decoder::Result result = Decoder::Result::ReadyForNext; + uint32_t code = data.peekBEInt(4); + data.drain(4); + // Startup message with 1234 in the most significant 16 bits + // indicate request to encrypt. + if (code >= 0x04d20000) { + encrypted_ = true; + // Handler for SSLRequest (Int32(80877103) = 0x04d2162f) + // See details in https://www.postgresql.org/docs/current/protocol-message-formats.html. + if (code == 0x04d2162f) { + // Notify the filter that `SSLRequest` message was decoded. + // If the filter returns true, it means to pass the message upstream + // to the server. If it returns false it means, that filter will try + // to terminate SSL session and SSLRequest should not be passed to the + // server. + encrypted_ = callbacks_->onSSLRequest(); + } + + // Count it as recognized frontend message. + callbacks_->incMessagesFrontend(); + if (encrypted_) { + ENVOY_LOG(trace, "postgres_proxy: detected encrypted traffic."); + incSessionsEncrypted(); + state_ = State::EncryptedState; } else { - ENVOY_LOG(debug, "Detected version {}.{} of Postgres", code >> 16, code & 0x0000FFFF); + result = Decoder::Result::Stopped; + // Stay in InitState. After switch to SSL, another init packet will be sent. } + } else { + ENVOY_LOG(debug, "Detected version {}.{} of Postgres", code >> 16, code & 0x0000FFFF); + state_ = State::InSyncState; } - data.drain(startup_ ? 4 : 5); // Length plus optional 1st byte. - - ENVOY_LOG(trace, "postgres_proxy: msg parsed"); - return Decoder::ReadyForNext; + processMessageBody(data, FRONTEND, message_len_ - 4, first_, msgParser); + data.drain(message_len_); + return result; } -Decoder::Result DecoderImpl::onData(Buffer::Instance& data, bool frontend) { - // If encrypted, just drain the traffic. - if (encrypted_) { - ENVOY_LOG(trace, "postgres_proxy: ignoring {} bytes of encrypted data", data.length()); - data.drain(data.length()); - return Decoder::ReadyForNext; - } +/* + Method invokes actions associated with message type and generate debug logs. +*/ +void DecoderImpl::processMessageBody(Buffer::Instance& data, absl::string_view direction, + uint32_t length, MessageProcessor& msg, + const std::unique_ptr& parser) { + uint32_t bytes_to_read = length; - if (!frontend && startup_) { - data.drain(data.length()); - return Decoder::ReadyForNext; + std::vector& actions = std::get<2>(msg); + if (!actions.empty()) { + // Linearize the message for processing. + message_.assign(std::string(static_cast(data.linearize(bytes_to_read)), bytes_to_read)); + + // Invoke actions associated with the type of received message. + for (const auto& action : actions) { + action(this); + } + + // Drop the linearized message. + message_.erase(); } + ENVOY_LOG(debug, "({}) command = {} ({})", direction, command_, std::get<0>(msg)); + ENVOY_LOG(debug, "({}) length = {}", direction, message_len_); + ENVOY_LOG(debug, "({}) message = {}", direction, genDebugMessage(parser, data, bytes_to_read)); + + ENVOY_LOG(trace, "postgres_proxy: {} bytes remaining in buffer", data.length()); + + data.drain(length); +} + +/* + onDataInSync is called when decoder is on-track with decoding messages. + All previous messages has been decoded properly and decoder is able to find + message boundaries. +*/ +Decoder::Result DecoderImpl::onDataInSync(Buffer::Instance& data, bool frontend) { ENVOY_LOG(trace, "postgres_proxy: decoding {} bytes", data.length()); - const Decoder::Result result = parseHeader(data); - if (result != Decoder::ReadyForNext || encrypted_) { - return result; + ENVOY_LOG(trace, "postgres_proxy: parsing message, len {}", data.length()); + + // The minimum size of the message sufficient for parsing is 5 bytes. + if (data.length() < 5) { + // not enough data in the buffer. + return Decoder::Result::NeedMoreData; } + data.copyOut(0, 1, &command_); + ENVOY_LOG(trace, "postgres_proxy: command is {}", command_); + + // The 1 byte message type and message length should be in the buffer + // Find the message processor and validate the message syntax. + MsgGroup& msg_processor = std::ref(frontend ? FE_messages_ : BE_messages_); frontend ? callbacks_->incMessagesFrontend() : callbacks_->incMessagesBackend(); @@ -265,45 +336,55 @@ Decoder::Result DecoderImpl::onData(Buffer::Instance& data, bool frontend) { // If message is found, the processing will be updated. std::reference_wrapper msg = msg_processor.unknown_; - if (startup_) { - msg = std::ref(first_); - startup_ = false; - } else { - auto it = msg_processor.messages_.find(command_); - if (it != msg_processor.messages_.end()) { - msg = std::ref((*it).second); - } + auto it = msg_processor.messages_.find(command_); + if (it != msg_processor.messages_.end()) { + msg = std::ref((*it).second); } - // message_len_ specifies total message length including 4 bytes long - // "length" field. The length of message body is total length minus size - // of "length" field (4 bytes). - uint32_t bytes_to_read = message_len_ - 4; - - std::vector& actions = std::get<2>(msg.get()); - if (!actions.empty()) { - // Linearize the message for processing. - message_.assign(std::string(static_cast(data.linearize(bytes_to_read)), bytes_to_read)); - - // Invoke actions associated with the type of received message. - for (const auto& action : actions) { - action(this); - } + // Validate the message before processing. + const MsgBodyReader& f = std::get<1>(msg.get()); + message_len_ = data.peekBEInt(1); + const auto msgParser = f(); + // Run the validation. + // Because the message validation may return NeedMoreData error, data must stay intact (no + // draining) until the remaining data arrives and validator will run again. Validator therefore + // starts at offset 5 (1 byte message type and 4 bytes of length). This is in contrast to + // processing of the message, which assumes that message has been validated and starts at the + // beginning of the message. + Message::ValidationResult validationResult = msgParser->validate(data, 5, message_len_ - 4); + + if (validationResult == Message::ValidationNeedMoreData) { + ENVOY_LOG(trace, "postgres_proxy: cannot parse message. Not enough bytes in the buffer."); + return Decoder::Result::NeedMoreData; + } - // Drop the linearized message. - message_.erase(); + if (validationResult == Message::ValidationFailed) { + // Message does not conform to the expected format. Move to out-of-sync state. + data.drain(data.length()); + state_ = State::OutOfSyncState; + return Decoder::Result::ReadyForNext; } - ENVOY_LOG(debug, "({}) command = {} ({})", msg_processor.direction_, command_, - std::get<0>(msg.get())); - ENVOY_LOG(debug, "({}) length = {}", msg_processor.direction_, message_len_); - ENVOY_LOG(debug, "({}) message = {}", msg_processor.direction_, - genDebugMessage(msg, data, bytes_to_read)); + // Drain message code and length fields. + // Processing the message assumes that message starts at the beginning of the buffer. + data.drain(5); - data.drain(bytes_to_read); - ENVOY_LOG(trace, "postgres_proxy: {} bytes remaining in buffer", data.length()); + processMessageBody(data, msg_processor.direction_, message_len_ - 4, msg, msgParser); - return Decoder::ReadyForNext; + return Decoder::Result::ReadyForNext; +} +/* + onDataIgnore method is called when the decoder does not inspect passing + messages. This happens when the decoder detected encrypted packets or + when the decoder could not validate passing messages and lost track of + messages boundaries. In order not to interpret received values as message + lengths and not to start buffering large amount of data, the decoder + enters OutOfSync state and starts ignoring passing messages. Once the + decoder enters OutOfSyncState it cannot leave that state. +*/ +Decoder::Result DecoderImpl::onDataIgnore(Buffer::Instance& data, bool) { + data.drain(data.length()); + return Decoder::Result::ReadyForNext; } // Method is called when C (CommandComplete) message has been @@ -423,16 +504,10 @@ void DecoderImpl::onStartup() { } // Method generates displayable format of currently processed message. -const std::string DecoderImpl::genDebugMessage(const MessageProcessor& msg, Buffer::Instance& data, - uint32_t message_len) { - const MsgBodyReader& f = std::get<1>(msg); - std::string message = "Unrecognized"; - if (f != nullptr) { - const auto msgParser = f(); - msgParser->read(data, message_len); - message = msgParser->toString(); - } - return message; +const std::string DecoderImpl::genDebugMessage(const std::unique_ptr& parser, + Buffer::Instance& data, uint32_t message_len) { + parser->read(data, message_len); + return parser->toString(); } } // namespace PostgresProxy diff --git a/source/extensions/filters/network/postgres_proxy/postgres_decoder.h b/source/extensions/filters/network/postgres_proxy/postgres_decoder.h index bed146c097a90..f62de0108f574 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_decoder.h +++ b/source/extensions/filters/network/postgres_proxy/postgres_decoder.h @@ -53,7 +53,7 @@ class Decoder { // The following values are returned by the decoder, when filter // passes bytes of data via onData method: - enum Result { + enum class Result { ReadyForNext, // Decoder processed previous message and is ready for the next message. NeedMoreData, // Decoder needs more data to reconstruct the message. Stopped // Received and processed message disrupts the current flow. Decoder stopped accepting @@ -84,12 +84,21 @@ class DecoderImpl : public Decoder, Logger::Loggable { std::string getMessage() { return message_; } - void setStartup(bool startup) { startup_ = startup; } void initialize(); bool encrypted() const { return encrypted_; } + enum class State { InitState, InSyncState, OutOfSyncState, EncryptedState }; + State state() const { return state_; } + void state(State state) { state_ = state; } + protected: + State state_{State::InitState}; + + Result onDataInit(Buffer::Instance& data, bool frontend); + Result onDataInSync(Buffer::Instance& data, bool frontend); + Result onDataIgnore(Buffer::Instance& data, bool frontend); + // MsgAction defines the Decoder's method which will be invoked // when a specific message has been decoded. using MsgAction = std::function; @@ -110,7 +119,7 @@ class DecoderImpl : public Decoder, Logger::Loggable { // Frontend and Backend messages. using MsgGroup = struct { // String describing direction (Frontend or Backend). - std::string direction_; + absl::string_view direction_; // Hash map indexed by messages' 1st byte points to handlers used for processing messages. absl::flat_hash_map messages_; // Handler used for processing messages not found in hash map. @@ -131,7 +140,8 @@ class DecoderImpl : public Decoder, Logger::Loggable { MsgAction unknown_; }; - Result parseHeader(Buffer::Instance& data); + void processMessageBody(Buffer::Instance& data, absl::string_view direction, uint32_t length, + MessageProcessor& msg, const std::unique_ptr& parser); void decode(Buffer::Instance& data); void decodeAuthentication(); void decodeBackendStatements(); @@ -149,18 +159,17 @@ class DecoderImpl : public Decoder, Logger::Loggable { // Helper method generating currently processed message in // displayable format. - const std::string genDebugMessage(const MessageProcessor& msg, Buffer::Instance& data, + const std::string genDebugMessage(const std::unique_ptr& parser, Buffer::Instance& data, uint32_t message_len); DecoderCallbacks* callbacks_{}; PostgresSession session_{}; // The following fields store result of message parsing. - char command_{}; + char command_{'-'}; std::string message_; uint32_t message_len_{}; - bool startup_{true}; // startup stage does not have 1st byte command bool encrypted_{false}; // tells if exchange is encrypted // Dispatchers for Backend (BE) and Frontend (FE) messages. @@ -178,6 +187,11 @@ class DecoderImpl : public Decoder, Logger::Loggable { MsgParserDict BE_errors_; MsgParserDict BE_notices_; + + // MAX_STARTUP_PACKET_LENGTH is defined in Postgres source code + // as maximum size of initial packet. + // https://github.com/postgres/postgres/search?q=MAX_STARTUP_PACKET_LENGTH&type=code + static constexpr uint64_t MAX_STARTUP_PACKET_LENGTH = 10000; }; } // namespace PostgresProxy diff --git a/source/extensions/filters/network/postgres_proxy/postgres_filter.cc b/source/extensions/filters/network/postgres_proxy/postgres_filter.cc index 71e3388718b17..0e8daac36fed4 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_filter.cc +++ b/source/extensions/filters/network/postgres_proxy/postgres_filter.cc @@ -231,11 +231,11 @@ Network::FilterStatus PostgresFilter::doDecode(Buffer::Instance& data, bool fron // that it cannot process data in the buffer. while (0 < data.length()) { switch (decoder_->onData(data, frontend)) { - case Decoder::NeedMoreData: + case Decoder::Result::NeedMoreData: return Network::FilterStatus::Continue; - case Decoder::ReadyForNext: + case Decoder::Result::ReadyForNext: continue; - case Decoder::Stopped: + case Decoder::Result::Stopped: return Network::FilterStatus::StopIteration; } } diff --git a/source/extensions/filters/network/postgres_proxy/postgres_message.cc b/source/extensions/filters/network/postgres_proxy/postgres_message.cc index 340092d489105..b8e4a3d5febe5 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_message.cc +++ b/source/extensions/filters/network/postgres_proxy/postgres_message.cc @@ -7,18 +7,16 @@ namespace PostgresProxy { // String type methods. bool String::read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { - // First find the terminating zero. - const char zero = 0; - const ssize_t index = data.search(&zero, 1, pos); - if (index == -1) { - return false; - } + // read method uses values set by validate method. + // This avoids unnecessary repetition of scanning data looking for terminating zero. + ASSERT(pos == start_); + ASSERT(end_ >= start_); // Reserve that many bytes in the string. - const uint64_t size = index - pos; + const uint64_t size = end_ - start_; value_.resize(size); // Now copy from buffer to string. - data.copyOut(pos, index - pos, value_.data()); + data.copyOut(pos, size, value_.data()); pos += (size + 1); left -= (size + 1); @@ -27,6 +25,35 @@ bool String::read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { std::string String::toString() const { return absl::StrCat("[", value_, "]"); } +Message::ValidationResult String::validate(const Buffer::Instance& data, + const uint64_t start_offset, uint64_t& pos, + uint64_t& left) { + // Try to find the terminating zero. + // If found, all is good. If not found, we may need more data. + const char zero = 0; + const ssize_t index = data.search(&zero, 1, pos); + if (index == -1) { + if (left <= (data.length() - pos)) { + // Message ended before finding terminating zero. + return Message::ValidationFailed; + } else { + return Message::ValidationNeedMoreData; + } + } + // Found, but after the message boundary. + const uint64_t size = index - pos; + if (size >= left) { + return Message::ValidationFailed; + } + + start_ = pos - start_offset; + end_ = start_ + size; + + pos += (size + 1); + left -= (size + 1); + return Message::ValidationOK; +} + // ByteN type methods. bool ByteN::read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { if (left > (data.length() - pos)) { @@ -38,6 +65,19 @@ bool ByteN::read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { left = 0; return true; } +// Since ByteN does not have a length field, it is not possible to verify +// its correctness. +Message::ValidationResult ByteN::validate(const Buffer::Instance& data, const uint64_t, + uint64_t& pos, uint64_t& left) { + if (left > (data.length() - pos)) { + return Message::ValidationNeedMoreData; + } + + pos += left; + left = 0; + + return Message::ValidationOK; +} std::string ByteN::toString() const { std::string out = "["; @@ -48,10 +88,7 @@ std::string ByteN::toString() const { // VarByteN type methods. bool VarByteN::read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { - if ((left < sizeof(int32_t)) || ((data.length() - pos) < sizeof(int32_t))) { - return false; - } - len_ = data.peekBEInt(pos); + // len_ was set by validator, skip it. pos += sizeof(int32_t); left -= sizeof(int32_t); if (len_ < 1) { @@ -59,10 +96,7 @@ bool VarByteN::read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) value_.clear(); return true; } - if ((left < static_cast(len_)) || - ((data.length() - pos) < static_cast(len_))) { - return false; - } + value_.resize(len_); data.copyOut(pos, len_, value_.data()); pos += len_; @@ -78,6 +112,42 @@ std::string VarByteN::toString() const { return out; } +Message::ValidationResult VarByteN::validate(const Buffer::Instance& data, const uint64_t, + uint64_t& pos, uint64_t& left) { + if (left < sizeof(int32_t)) { + // Malformed message. + return Message::ValidationFailed; + } + + if ((data.length() - pos) < sizeof(int32_t)) { + return Message::ValidationNeedMoreData; + } + + // Read length of the VarByteN structure. + len_ = data.peekBEInt(pos); + if (static_cast(len_) > static_cast(left)) { + // VarByteN would extend past the current message boundaries. + // Lengths of message and individual fields do not match. + return Message::ValidationFailed; + } + + if (len_ < 1) { + // There is no payload if length is not positive. + pos += sizeof(int32_t); + left -= sizeof(int32_t); + return Message::ValidationOK; + } + + if ((data.length() - pos) < (len_ + sizeof(int32_t))) { + return Message::ValidationNeedMoreData; + } + + pos += (len_ + sizeof(int32_t)); + left -= (len_ + sizeof(int32_t)); + + return Message::ValidationOK; +} + } // namespace PostgresProxy } // namespace NetworkFilters } // namespace Extensions diff --git a/source/extensions/filters/network/postgres_proxy/postgres_message.h b/source/extensions/filters/network/postgres_proxy/postgres_message.h index 948167892b148..584662f0b7546 100644 --- a/source/extensions/filters/network/postgres_proxy/postgres_message.h +++ b/source/extensions/filters/network/postgres_proxy/postgres_message.h @@ -29,6 +29,30 @@ namespace PostgresProxy { * */ +// Interface to Postgres message class. +class Message { +public: + enum ValidationResult { ValidationFailed, ValidationOK, ValidationNeedMoreData }; + + virtual ~Message() = default; + + // read method should read only as many bytes from data + // buffer as it is indicated in message's length field. + // "length" parameter indicates how many bytes were indicated in Postgres message's + // length field. "data" buffer may contain more bytes than "length". + virtual bool read(const Buffer::Instance& data, const uint64_t length) PURE; + + virtual ValidationResult validate(const Buffer::Instance& data, const uint64_t, + const uint64_t) PURE; + + // toString method provides displayable representation of + // the Postgres message. + virtual std::string toString() const PURE; + +protected: + ValidationResult validation_result_{ValidationNeedMoreData}; +}; + // Template for integer types. // Size of integer types is fixed and depends on the type of integer. template class Int { @@ -46,15 +70,27 @@ template class Int { * for the current message. */ bool read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { - if ((data.length() - pos) < sizeof(T)) { - return false; - } value_ = data.peekBEInt(pos); pos += sizeof(T); left -= sizeof(T); return true; } + Message::ValidationResult validate(const Buffer::Instance& data, const uint64_t, uint64_t& pos, + uint64_t& left) { + if (left < sizeof(T)) { + return Message::ValidationFailed; + } + + if ((data.length() - pos) < sizeof(T)) { + return Message::ValidationNeedMoreData; + } + + pos += sizeof(T); + left -= sizeof(T); + return Message::ValidationOK; + } + std::string toString() const { return fmt::format("[{}]", value_); } T get() const { return value_; } @@ -78,8 +114,13 @@ class String { */ bool read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left); std::string toString() const; + Message::ValidationResult validate(const Buffer::Instance&, const uint64_t start_offset, + uint64_t&, uint64_t&); private: + // start_ and end_ are set by validate method. + uint64_t start_; + uint64_t end_; std::string value_; }; @@ -92,6 +133,7 @@ class ByteN { */ bool read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left); std::string toString() const; + Message::ValidationResult validate(const Buffer::Instance&, const uint64_t, uint64_t&, uint64_t&); private: std::vector value_; @@ -115,6 +157,7 @@ class VarByteN { */ bool read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left); std::string toString() const; + Message::ValidationResult validate(const Buffer::Instance&, const uint64_t, uint64_t&, uint64_t&); private: int32_t len_; @@ -128,22 +171,11 @@ template class Array { * See above for parameter and return value description. */ bool read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { - // First read the 16 bits value which indicates how many - // elements there are in the array. - if (((data.length() - pos) < sizeof(uint16_t)) || (left < sizeof(uint16_t))) { - return false; - } - const uint16_t num = data.peekBEInt(pos); + // Skip reading the size of array. The validator did it. pos += sizeof(uint16_t); left -= sizeof(uint16_t); - if (num != 0) { - for (uint16_t i = 0; i < num; i++) { - auto item = std::make_unique(); - if (!item->read(data, pos, left)) { - return false; - } - value_.push_back(std::move(item)); - } + for (uint16_t i = 0; i < size_; i++) { + value_[i]->read(data, pos, left); } return true; } @@ -161,8 +193,41 @@ template class Array { return out; } + Message::ValidationResult validate(const Buffer::Instance& data, const uint64_t start_offset, + uint64_t& pos, uint64_t& left) { + // First read the 16 bits value which indicates how many + // elements there are in the array. + if (left < sizeof(uint16_t)) { + return Message::ValidationFailed; + } + + if ((data.length() - pos) < sizeof(uint16_t)) { + return Message::ValidationNeedMoreData; + } + + size_ = data.peekBEInt(pos); + uint64_t orig_pos = pos; + uint64_t orig_left = left; + pos += sizeof(uint16_t); + left -= sizeof(uint16_t); + if (size_ != 0) { + for (uint16_t i = 0; i < size_; i++) { + auto item = std::make_unique(); + Message::ValidationResult result = item->validate(data, start_offset, pos, left); + if (Message::ValidationOK != result) { + pos = orig_pos; + left = orig_left; + value_.clear(); + return result; + } + value_.push_back(std::move(item)); + } + } + return Message::ValidationOK; + } private: + uint16_t size_; std::vector> value_; }; @@ -175,16 +240,10 @@ template class Repeated { * See above for parameter and return value description. */ bool read(const Buffer::Instance& data, uint64_t& pos, uint64_t& left) { - if ((data.length() - pos) < left) { - return false; - } - // Read until nothing is left. - while (left != 0) { - auto item = std::make_unique(); - if (!item->read(data, pos, left)) { + for (size_t i = 0; i < value_.size(); i++) { + if (!value_[i]->read(data, pos, left)) { return false; } - value_.push_back(std::move(item)); } return true; } @@ -200,47 +259,45 @@ template class Repeated { } return out; } + Message::ValidationResult validate(const Buffer::Instance& data, const uint64_t start_offset, + uint64_t& pos, uint64_t& left) { + if ((data.length() - pos) < left) { + return Message::ValidationNeedMoreData; + } -private: - std::vector> value_; -}; - -// Interface to Postgres message class. -class Message { -public: - virtual ~Message() = default; + // Validate until the end of the message. + uint64_t orig_pos = pos; + uint64_t orig_left = left; + while (left != 0) { + auto item = std::make_unique(); + Message::ValidationResult result = item->validate(data, start_offset, pos, left); + if (Message::ValidationOK != result) { + pos = orig_pos; + left = orig_left; + value_.clear(); + return result; + } + value_.push_back(std::move(item)); + } - // read method should read only as many bytes from data - // buffer as it is indicated in message's length field. - // "length" parameter indicates how many bytes were indicated in Postgres message's - // length field. "data" buffer may contain more bytes than "length". - virtual bool read(const Buffer::Instance& data, const uint64_t length) PURE; + return Message::ValidationOK; + } - // toString method provides displayable representation of - // the Postgres message. - virtual std::string toString() const PURE; +private: + std::vector> value_; }; // Sequence is tuple like structure, which binds together // set of several fields of different types. template class Sequence; -template -class Sequence : public Message { +template class Sequence { FirstField first_; Sequence remaining_; public: Sequence() = default; - std::string toString() const override { - return absl::StrCat(first_.toString(), remaining_.toString()); - } - - bool read(const Buffer::Instance& data, const uint64_t length) override { - uint64_t pos = 0; - uint64_t left = length; - return read(data, pos, left); - } + std::string toString() const { return absl::StrCat(first_.toString(), remaining_.toString()); } /** * Implementation of "read" method for variadic template. @@ -255,21 +312,56 @@ class Sequence : public Message { } return remaining_.read(data, pos, left); } + + Message::ValidationResult validate(const Buffer::Instance& data, const uint64_t start_offset, + uint64_t& pos, uint64_t& left) { + Message::ValidationResult result = first_.validate(data, start_offset, pos, left); + if (result != Message::ValidationOK) { + return result; + } + return remaining_.validate(data, start_offset, pos, left); + } }; // Terminal template definition for variadic Sequence template. -template <> class Sequence<> : public Message { +template <> class Sequence<> { public: Sequence<>() = default; - std::string toString() const override { return ""; } + std::string toString() const { return ""; } bool read(const Buffer::Instance&, uint64_t&, uint64_t&) { return true; } - bool read(const Buffer::Instance&, const uint64_t) override { return true; } + Message::ValidationResult validate(const Buffer::Instance&, const uint64_t, uint64_t&, + uint64_t& left) { + return left == 0 ? Message::ValidationOK : Message::ValidationFailed; + } +}; + +template class MessageImpl : public Message, public Sequence { +public: + ~MessageImpl() override = default; + bool read(const Buffer::Instance& data, const uint64_t length) override { + // Do not call read unless validation was successful. + ASSERT(validation_result_ == ValidationOK); + uint64_t pos = 0; + uint64_t left = length; + return Sequence::read(data, pos, left); + } + Message::ValidationResult validate(const Buffer::Instance& data, const uint64_t start_pos, + const uint64_t length) override { + uint64_t pos = start_pos; + uint64_t left = length; + validation_result_ = Sequence::validate(data, start_pos, pos, left); + return validation_result_; + } + std::string toString() const override { return Sequence::toString(); } + +private: + // Message::ValidationResult validation_result_; }; // Helper function to create pointer to a Sequence structure and is used by Postgres // decoder after learning the type of Postgres message. template std::unique_ptr createMsgBodyReader() { - return std::make_unique>(); + return std::make_unique>(); } } // namespace PostgresProxy diff --git a/test/extensions/filters/network/postgres_proxy/BUILD b/test/extensions/filters/network/postgres_proxy/BUILD index f121e6b178e2e..10a2680e00f95 100644 --- a/test/extensions/filters/network/postgres_proxy/BUILD +++ b/test/extensions/filters/network/postgres_proxy/BUILD @@ -19,6 +19,7 @@ envoy_extension_cc_test_library( extension_name = "envoy.filters.network.postgres_proxy", deps = [ "//source/common/buffer:buffer_lib", + "//source/extensions/filters/network/postgres_proxy:filter", ], ) diff --git a/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc b/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc index cc7f65cdcd035..3c6e05bf9f1fd 100644 --- a/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc +++ b/test/extensions/filters/network/postgres_proxy/postgres_decoder_test.cc @@ -33,7 +33,7 @@ class PostgresProxyDecoderTestBase { PostgresProxyDecoderTestBase() { decoder_ = std::make_unique(&callbacks_); decoder_->initialize(); - decoder_->setStartup(false); + decoder_->state(DecoderImpl::State::InSyncState); } protected: @@ -60,6 +60,10 @@ class PostgresProxyFrontendEncrDecoderTest : public PostgresProxyDecoderTestBase class PostgresProxyBackendDecoderTest : public PostgresProxyDecoderTestBase, public ::testing::TestWithParam {}; +class PostgresProxyBackendStatementTest + : public PostgresProxyDecoderTestBase, + public ::testing::TestWithParam> {}; + class PostgresProxyErrorTest : public PostgresProxyDecoderTestBase, public ::testing::TestWithParam> {}; @@ -75,7 +79,7 @@ class PostgresProxyNoticeTest // startup message the server should start using message format // with command as 1st byte. TEST_F(PostgresProxyDecoderTest, StartupMessage) { - decoder_->setStartup(true); + decoder_->state(DecoderImpl::State::InitState); buf_[0] = '\0'; // Startup message has the following structure: @@ -98,29 +102,25 @@ TEST_F(PostgresProxyDecoderTest, StartupMessage) { // Some other attribute data_.add("attribute"); // 9 bytes data_.add(buf_, 1); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::NeedMoreData); data_.add("blah"); // 4 bytes + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::NeedMoreData); data_.add(buf_, 1); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); ASSERT_THAT(data_.length(), 0); + // Decoder should move to InSyncState + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // Verify parsing attributes ASSERT_THAT(decoder_->getAttributes().at("user"), "postgres"); ASSERT_THAT(decoder_->getAttributes().at("database"), "testdb"); // This attribute should not be found ASSERT_THAT(decoder_->getAttributes().find("no"), decoder_->getAttributes().end()); - - // Now feed normal message with 1bytes as command. - data_.add("P"); - // Add length. - data_.writeBEInt(6); // 4 bytes of length + 2 bytes of data. - data_.add("AB"); - decoder_->onData(data_, true); - ASSERT_THAT(data_.length(), 0); } // Test verifies that when Startup message does not carry // "database" attribute, it is derived from "user". TEST_F(PostgresProxyDecoderTest, StartupMessageNoAttr) { - decoder_->setStartup(true); + decoder_->state(DecoderImpl::State::InitState); buf_[0] = '\0'; // Startup message has the following structure: @@ -141,7 +141,8 @@ TEST_F(PostgresProxyDecoderTest, StartupMessageNoAttr) { data_.add(buf_, 1); data_.add("blah"); // 4 bytes data_.add(buf_, 1); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_THAT(data_.length(), 0); // Verify parsing attributes @@ -151,53 +152,109 @@ TEST_F(PostgresProxyDecoderTest, StartupMessageNoAttr) { ASSERT_THAT(decoder_->getAttributes().find("no"), decoder_->getAttributes().end()); } +TEST_F(PostgresProxyDecoderTest, InvalidStartupMessage) { + decoder_->state(DecoderImpl::State::InitState); + + // Create a bogus message with incorrect syntax. + // Length is 10 bytes. + data_.writeBEInt(10); + for (auto i = 0; i < 6; i++) { + data_.writeBEInt(i); + } + + // Decoder should move to OutOfSync state. + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::OutOfSyncState); + ASSERT_THAT(data_.length(), 0); + + // All-zeros message. + data_.writeBEInt(0); + for (auto i = 0; i < 6; i++) { + data_.writeBEInt(0); + } + + // Decoder should move to OutOfSync state. + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::OutOfSyncState); + ASSERT_THAT(data_.length(), 0); +} + +// Test that decoder does not crash when it receives +// random data in InitState. +TEST_F(PostgresProxyDecoderTest, StartupMessageRandomData) { + srand(time(nullptr)); + for (auto i = 0; i < 10000; i++) { + decoder_->state(DecoderImpl::State::InSyncState); + // Generate random length. + uint32_t len = rand() % 20000; + // Now fill the buffer with random data. + for (uint32_t j = 0; j < len; j++) { + data_.writeBEInt(rand() % 1024); + uint8_t data = static_cast(rand() % 256); + data_.writeBEInt(data); + } + // Feed the buffer to the decoder. It should not crash. + decoder_->onData(data_, true); + + // Reset the buffer for the next iteration. + data_.drain(data_.length()); + } +} + // Test processing messages which map 1:1 with buffer. // The buffer contains just a single entire message and // nothing more. TEST_F(PostgresProxyDecoderTest, ReadingBufferSingleMessages) { - + decoder_->state(DecoderImpl::State::InSyncState); // Feed empty buffer - should not crash. - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::NeedMoreData); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // Put one byte. This is not enough to parse the message and that byte // should stay in the buffer. - data_.add("P"); - decoder_->onData(data_, true); + data_.add("H"); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::NeedMoreData); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_THAT(data_.length(), 1); // Add length of 4 bytes. It would mean completely empty message. // but it should be consumed. data_.writeBEInt(4); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_THAT(data_.length(), 0); // Create a message with 5 additional bytes. - data_.add("P"); + data_.add("d"); // Add length. data_.writeBEInt(9); // 4 bytes of length field + 5 of data. data_.add(buf_, 5); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); ASSERT_THAT(data_.length(), 0); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } // Test simulates situation when decoder is called with incomplete message. // The message should not be processed until the buffer is filled // with missing bytes. TEST_F(PostgresProxyDecoderTest, ReadingBufferLargeMessages) { + decoder_->state(DecoderImpl::State::InSyncState); // Fill the buffer with message of 100 bytes long // but the buffer contains only 98 bytes. // It should not be processed. - data_.add("P"); + data_.add("d"); // Add length. data_.writeBEInt(100); // This also includes length field data_.add(buf_, 94); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::NeedMoreData); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // The buffer contains command (1 byte), length (4 bytes) and 94 bytes of message. ASSERT_THAT(data_.length(), 99); // Add 2 missing bytes and feed again to decoder. data_.add("AB"); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_THAT(data_.length(), 0); } @@ -205,14 +262,15 @@ TEST_F(PostgresProxyDecoderTest, ReadingBufferLargeMessages) { // message. Call to the decoder should consume only one message // at a time and only when the buffer contains the entire message. TEST_F(PostgresProxyDecoderTest, TwoMessagesInOneBuffer) { + decoder_->state(DecoderImpl::State::InSyncState); // Create the first message of 50 bytes long (+1 for command). - data_.add("P"); + data_.add("d"); // Add length. data_.writeBEInt(50); data_.add(buf_, 46); // Create the second message of 50 + 46 bytes (+1 for command). - data_.add("P"); + data_.add("d"); // Add length. data_.writeBEInt(96); data_.add(buf_, 46); @@ -223,49 +281,72 @@ TEST_F(PostgresProxyDecoderTest, TwoMessagesInOneBuffer) { // 2nd: command (1 byte), length (4 bytes), 92 bytes of data ASSERT_THAT(data_.length(), 148); // Process the first message. - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_THAT(data_.length(), 97); // Process the second message. - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_THAT(data_.length(), 0); } TEST_F(PostgresProxyDecoderTest, Unknown) { + decoder_->state(DecoderImpl::State::InSyncState); // Create invalid message. The first byte is invalid "=" // Message must be at least 5 bytes to be parsed. EXPECT_CALL(callbacks_, incMessagesUnknown()); createPostgresMsg(data_, "=", "some not important string which will be ignored anyways"); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(data_.length(), 0); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); +} + +// Test verifies that decoder goes into OutOfSyncState when +// it encounters a message with wrong syntax. +TEST_F(PostgresProxyDecoderTest, IncorrectMessages) { + decoder_->state(DecoderImpl::State::InSyncState); + + // Create incorrect message. Message syntax is + // 1 byte type ('f'), 4 bytes of length and zero terminated string. + data_.add("f"); + data_.writeBEInt(8); + // Do not write terminating zero for the string. + data_.add("test"); + + // The decoder will indicate that is is ready for more data, but + // will enter OutOfSyncState. + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::OutOfSyncState); } -// Test if each frontend command calls incMessagesFrontend() method. -TEST_P(PostgresProxyFrontendDecoderTest, FrontendInc) { +// Test if frontend command calls incMessagesFrontend() method. +TEST_F(PostgresProxyFrontendDecoderTest, FrontendInc) { + decoder_->state(DecoderImpl::State::InSyncState); EXPECT_CALL(callbacks_, incMessagesFrontend()); - createPostgresMsg(data_, GetParam(), "SELECT 1;"); - decoder_->onData(data_, true); + createPostgresMsg(data_, "f", "some text"); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // Make sure that decoder releases memory used during message processing. ASSERT_TRUE(decoder_->getMessage().empty()); } -// Run the above test for each frontend message. -INSTANTIATE_TEST_SUITE_P(FrontEndMessagesTests, PostgresProxyFrontendDecoderTest, - ::testing::Values("B", "C", "d", "c", "f", "D", "E", "H", "F", "p", "P", - "p", "Q", "S", "X")); - // Test if X message triggers incRollback and sets proper state in transaction. TEST_F(PostgresProxyFrontendDecoderTest, TerminateMessage) { + decoder_->state(DecoderImpl::State::InSyncState); // Set decoder state NOT to be in_transaction. decoder_->getSession().setInTransaction(false); EXPECT_CALL(callbacks_, incTransactionsRollback()).Times(0); createPostgresMsg(data_, "X"); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // Now set the decoder to be in_transaction state. decoder_->getSession().setInTransaction(true); EXPECT_CALL(callbacks_, incTransactionsRollback()); createPostgresMsg(data_, "X"); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); ASSERT_FALSE(decoder_->getSession().inTransaction()); } @@ -273,7 +354,8 @@ TEST_F(PostgresProxyFrontendDecoderTest, TerminateMessage) { TEST_F(PostgresProxyFrontendDecoderTest, QueryMessage) { EXPECT_CALL(callbacks_, processQuery); createPostgresMsg(data_, "Q", "SELECT * FROM whatever;"); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } // Parse message has optional Query name which may be in front of actual @@ -295,7 +377,8 @@ TEST_F(PostgresProxyFrontendDecoderTest, ParseMessage) { query_name.reserve(1); query_name += '\0'; createPostgresMsg(data_, "P", query_name + query + query_params); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // Message with optional name query_name query_name.clear(); @@ -303,21 +386,18 @@ TEST_F(PostgresProxyFrontendDecoderTest, ParseMessage) { query_name += "P0_8"; query_name += '\0'; createPostgresMsg(data_, "P", query_name + query + query_params); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, true), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } -// Test if each backend command calls incMessagesBackend()) method. -TEST_P(PostgresProxyBackendDecoderTest, BackendInc) { +// Test if backend command calls incMessagesBackend()) method. +TEST_F(PostgresProxyBackendDecoderTest, BackendInc) { EXPECT_CALL(callbacks_, incMessagesBackend()); - createPostgresMsg(data_, GetParam(), "Some not important message"); - decoder_->onData(data_, false); + createPostgresMsg(data_, "I"); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } -// Run the above test for each backend message. -INSTANTIATE_TEST_SUITE_P(BackendMessagesTests, PostgresProxyBackendDecoderTest, - ::testing::Values("R", "K", "2", "3", "C", "d", "c", "G", "H", "D", "I", - "E", "V", "v", "n", "N", "A", "t", "S", "1", "s", "Z", - "T")); // Test parsing backend messages. // The parser should react only to the first word until the space. TEST_F(PostgresProxyBackendDecoderTest, ParseStatement) { @@ -325,80 +405,93 @@ TEST_F(PostgresProxyBackendDecoderTest, ParseStatement) { // Rollback counter should be bumped up. EXPECT_CALL(callbacks_, incTransactionsRollback()); createPostgresMsg(data_, "C", "ROLLBACK 123"); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); data_.drain(data_.length()); // Now try just keyword without a space at the end. EXPECT_CALL(callbacks_, incTransactionsRollback()); createPostgresMsg(data_, "C", "ROLLBACK"); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); data_.drain(data_.length()); // Partial message should be ignored. EXPECT_CALL(callbacks_, incTransactionsRollback()).Times(0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Other)); createPostgresMsg(data_, "C", "ROLL"); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); data_.drain(data_.length()); // Keyword without a space should be ignored. EXPECT_CALL(callbacks_, incTransactionsRollback()).Times(0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Other)); createPostgresMsg(data_, "C", "ROLLBACK123"); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); data_.drain(data_.length()); } // Test Backend messages and make sure that they // trigger proper stats updates. TEST_F(PostgresProxyDecoderTest, Backend) { + decoder_->state(DecoderImpl::State::InSyncState); // C message EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Other)); createPostgresMsg(data_, "C", "BEGIN 123"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); ASSERT_TRUE(decoder_->getSession().inTransaction()); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Other)); createPostgresMsg(data_, "C", "START TR"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Other)); EXPECT_CALL(callbacks_, incTransactionsCommit()); createPostgresMsg(data_, "C", "COMMIT"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Select)); EXPECT_CALL(callbacks_, incTransactionsCommit()); createPostgresMsg(data_, "C", "SELECT"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Other)); EXPECT_CALL(callbacks_, incTransactionsRollback()); createPostgresMsg(data_, "C", "ROLLBACK"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Insert)); EXPECT_CALL(callbacks_, incTransactionsCommit()); createPostgresMsg(data_, "C", "INSERT 1"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Update)); EXPECT_CALL(callbacks_, incTransactionsCommit()); createPostgresMsg(data_, "C", "UPDATE 123"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); EXPECT_CALL(callbacks_, incStatements(DecoderCallbacks::StatementType::Delete)); EXPECT_CALL(callbacks_, incTransactionsCommit()); createPostgresMsg(data_, "C", "DELETE 88"); - decoder_->onData(data_, false); - data_.drain(data_.length()); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); + ASSERT_THAT(data_.length(), 0); } // Test checks deep inspection of the R message. @@ -412,7 +505,8 @@ TEST_F(PostgresProxyBackendDecoderTest, AuthenticationMsg) { // sessions must not be increased. EXPECT_CALL(callbacks_, incSessionsUnencrypted()).Times(0); createPostgresMsg(data_, "R", "blah blah"); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); data_.drain(data_.length()); // Create the correct payload which means that @@ -423,7 +517,8 @@ TEST_F(PostgresProxyBackendDecoderTest, AuthenticationMsg) { data_.writeBEInt(8); // Add 4-byte code. data_.writeBEInt(0); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); data_.drain(data_.length()); } @@ -432,7 +527,8 @@ TEST_F(PostgresProxyBackendDecoderTest, AuthenticationMsg) { TEST_P(PostgresProxyErrorTest, ParseErrorMsgs) { EXPECT_CALL(callbacks_, incErrors(std::get<1>(GetParam()))); createPostgresMsg(data_, "E", std::get<0>(GetParam())); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } INSTANTIATE_TEST_SUITE_P( @@ -461,7 +557,8 @@ INSTANTIATE_TEST_SUITE_P( TEST_P(PostgresProxyNoticeTest, ParseNoticeMsgs) { EXPECT_CALL(callbacks_, incNotices(std::get<1>(GetParam()))); createPostgresMsg(data_, "N", std::get<0>(GetParam())); - decoder_->onData(data_, false); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } INSTANTIATE_TEST_SUITE_P( @@ -478,10 +575,10 @@ INSTANTIATE_TEST_SUITE_P( // that protocol uses encryption. TEST_P(PostgresProxyFrontendEncrDecoderTest, EncyptedTraffic) { // Set decoder to wait for initial message. - decoder_->setStartup(true); + decoder_->state(DecoderImpl::State::InitState); // Initial state is no-encryption. - ASSERT_FALSE(decoder_->encrypted()); + // ASSERT_FALSE(decoder_->encrypted()); // Indicate that decoder should continue with processing the message. ON_CALL(callbacks_, onSSLRequest).WillByDefault(testing::Return(true)); @@ -493,8 +590,11 @@ TEST_P(PostgresProxyFrontendEncrDecoderTest, EncyptedTraffic) { // 1234 in the most significant 16 bits, and some code in the least significant 16 bits. // Add 4 bytes long code data_.writeBEInt(GetParam()); - decoder_->onData(data_, true); - ASSERT_TRUE(decoder_->encrypted()); + // Decoder should indicate that it is ready for mode data and entered + // encrypted state. + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::EncryptedState); + // ASSERT_TRUE(decoder_->encrypted()); // Decoder should drain data. ASSERT_THAT(data_.length(), 0); @@ -503,7 +603,8 @@ TEST_P(PostgresProxyFrontendEncrDecoderTest, EncyptedTraffic) { EXPECT_CALL(callbacks_, incMessagesFrontend()).Times(0); createPostgresMsg(data_, "P", "Some message just to fill the payload."); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::EncryptedState); // Decoder should drain data. ASSERT_THAT(data_.length(), 0); } @@ -517,7 +618,7 @@ INSTANTIATE_TEST_SUITE_P(FrontendEncryptedMessagesTests, PostgresProxyFrontendEn // Test onSSLRequest callback. TEST_F(PostgresProxyDecoderTest, TerminateSSL) { // Set decoder to wait for initial message. - decoder_->setStartup(true); + decoder_->state(DecoderImpl::State::InitState); // Indicate that decoder should not continue with processing the message // because filter will try to terminate SSL session. @@ -528,9 +629,10 @@ TEST_F(PostgresProxyDecoderTest, TerminateSSL) { // 1234 in the most significant 16 bits, and some code in the least significant 16 bits. // Add 4 bytes long code data_.writeBEInt(80877103); - decoder_->onData(data_, true); + ASSERT_THAT(decoder_->onData(data_, false), Decoder::Result::Stopped); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InitState); - // Decoder should interpret the session as encrypted stream. + // Decoder should interpret the session as clear-text stream. ASSERT_FALSE(decoder_->encrypted()); } @@ -569,11 +671,10 @@ class FakeBuffer : public Buffer::Instance { // Test verifies that decoder calls Buffer::linearize method // for messages which have associated 'action'. TEST_F(PostgresProxyDecoderTest, Linearize) { + decoder_->state(DecoderImpl::State::InSyncState); testing::NiceMock fake_buf; uint8_t body[] = "test\0"; - decoder_->setStartup(false); - // Simulate that decoder reads message which needs processing. // Query 'Q' message's body is just string. // Message header is 5 bytes and body will contain string "test\0". @@ -600,7 +701,8 @@ TEST_F(PostgresProxyDecoderTest, Linearize) { // It should call "Buffer::linearize". EXPECT_CALL(fake_buf, linearize).WillOnce([&](uint32_t) -> void* { return body; }); - decoder_->onData(fake_buf, false); + ASSERT_THAT(decoder_->onData(fake_buf, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); // Simulate that decoder reads message which does not need processing. // BindComplete message has type '2' and empty body. @@ -622,7 +724,8 @@ TEST_F(PostgresProxyDecoderTest, Linearize) { // Make sure that decoder does not call linearize. EXPECT_CALL(fake_buf, linearize).Times(0); - decoder_->onData(fake_buf, false); + ASSERT_THAT(decoder_->onData(fake_buf, false), Decoder::Result::ReadyForNext); + ASSERT_THAT(decoder_->state(), DecoderImpl::State::InSyncState); } } // namespace PostgresProxy diff --git a/test/extensions/filters/network/postgres_proxy/postgres_filter_test.cc b/test/extensions/filters/network/postgres_proxy/postgres_filter_test.cc index 1214dc4c661e3..6564883727602 100644 --- a/test/extensions/filters/network/postgres_proxy/postgres_filter_test.cc +++ b/test/extensions/filters/network/postgres_proxy/postgres_filter_test.cc @@ -84,7 +84,7 @@ TEST_P(PostgresFilterTest, ReadData) { EXPECT_CALL(*decoderPtr, onData) .WillOnce(WithArgs<0, 1>(Invoke([](Buffer::Instance& data, bool) -> Decoder::Result { data.drain(data.length()); - return Decoder::ReadyForNext; + return Decoder::Result::ReadyForNext; }))); std::get<0>(GetParam())(filter_.get(), data_, false); ASSERT_THAT(std::get<1>(GetParam())(filter_.get()), 0); @@ -93,11 +93,11 @@ TEST_P(PostgresFilterTest, ReadData) { EXPECT_CALL(*decoderPtr, onData) .WillOnce(WithArgs<0, 1>(Invoke([](Buffer::Instance& data, bool) -> Decoder::Result { data.drain(100); - return Decoder::ReadyForNext; + return Decoder::Result::ReadyForNext; }))) .WillOnce(WithArgs<0, 1>(Invoke([](Buffer::Instance& data, bool) -> Decoder::Result { data.drain(156); - return Decoder::ReadyForNext; + return Decoder::Result::ReadyForNext; }))); std::get<0>(GetParam())(filter_.get(), data_, false); ASSERT_THAT(std::get<1>(GetParam())(filter_.get()), 0); @@ -108,15 +108,15 @@ TEST_P(PostgresFilterTest, ReadData) { EXPECT_CALL(*decoderPtr, onData) .WillOnce(WithArgs<0, 1>(Invoke([](Buffer::Instance& data, bool) -> Decoder::Result { data.drain(100); - return Decoder::ReadyForNext; + return Decoder::Result::ReadyForNext; }))) .WillOnce(WithArgs<0, 1>(Invoke([](Buffer::Instance& data, bool) -> Decoder::Result { data.drain(100); - return Decoder::ReadyForNext; + return Decoder::Result::ReadyForNext; }))) .WillOnce(WithArgs<0, 1>(Invoke([](Buffer::Instance& data, bool) -> Decoder::Result { data.drain(0); - return Decoder::NeedMoreData; + return Decoder::Result::NeedMoreData; }))); std::get<0>(GetParam())(filter_.get(), data_, false); ASSERT_THAT(std::get<1>(GetParam())(filter_.get()), 56); @@ -135,7 +135,7 @@ INSTANTIATE_TEST_SUITE_P(ProcessDataTests, PostgresFilterTest, // It expects that certain statistics are updated. TEST_F(PostgresFilterTest, BackendMsgsStats) { // pretend that startup message has been received. - static_cast(filter_->getDecoder())->setStartup(false); + static_cast(filter_->getDecoder())->state(DecoderImpl::State::InSyncState); // unknown message createPostgresMsg(data_, "=", "blah blah blah"); @@ -230,7 +230,7 @@ TEST_F(PostgresFilterTest, BackendMsgsStats) { // verifies that statistic counters are increased. TEST_F(PostgresFilterTest, ErrorMsgsStats) { // Pretend that startup message has been received. - static_cast(filter_->getDecoder())->setStartup(false); + static_cast(filter_->getDecoder())->state(DecoderImpl::State::InSyncState); createPostgresMsg(data_, "E", "SERRORVERRORC22012"); filter_->onWrite(data_, false); @@ -257,7 +257,7 @@ TEST_F(PostgresFilterTest, ErrorMsgsStats) { // that corresponding stats counters are updated. TEST_F(PostgresFilterTest, NoticeMsgsStats) { // Pretend that startup message has been received. - static_cast(filter_->getDecoder())->setStartup(false); + static_cast(filter_->getDecoder())->state(DecoderImpl::State::InSyncState); createPostgresMsg(data_, "N", "SblalalaC2345"); filter_->onWrite(data_, false); @@ -304,7 +304,7 @@ TEST_F(PostgresFilterTest, EncryptedSessionStats) { // Postgres metadata. TEST_F(PostgresFilterTest, MetadataIncorrectSQL) { // Pretend that startup message has been received. - static_cast(filter_->getDecoder())->setStartup(false); + static_cast(filter_->getDecoder())->state(DecoderImpl::State::InSyncState); setMetadata(); createPostgresMsg(data_, "Q", "BLAH blah blah"); @@ -322,7 +322,7 @@ TEST_F(PostgresFilterTest, MetadataIncorrectSQL) { // and it happens only when parse_sql flag is true. TEST_F(PostgresFilterTest, QueryMessageMetadata) { // Pretend that startup message has been received. - static_cast(filter_->getDecoder())->setStartup(false); + static_cast(filter_->getDecoder())->state(DecoderImpl::State::InSyncState); setMetadata(); // Disable creating parsing SQL and creating metadata. diff --git a/test/extensions/filters/network/postgres_proxy/postgres_message_test.cc b/test/extensions/filters/network/postgres_proxy/postgres_message_test.cc index ec7b8e1b713ba..9fb2d5c89277d 100644 --- a/test/extensions/filters/network/postgres_proxy/postgres_message_test.cc +++ b/test/extensions/filters/network/postgres_proxy/postgres_message_test.cc @@ -28,7 +28,17 @@ TYPED_TEST_SUITE(IntTest, IntTypes); TYPED_TEST(IntTest, BasicRead) { this->data_.template writeBEInt().get())>(12); uint64_t pos = 0; - uint64_t left = this->data_.length(); + uint64_t left; + // Simulate that message is too short. + left = sizeof(TypeParam) - 1; + ASSERT_THAT(Message::ValidationFailed, this->field_.validate(this->data_, 0, pos, left)); + // Single 4-byte int. Message length is correct. + left = sizeof(TypeParam); + ASSERT_THAT(Message::ValidationOK, this->field_.validate(this->data_, 0, pos, left)); + + // Read the value after successful validation. + pos = 0; + left = sizeof(TypeParam); ASSERT_TRUE(this->field_.read(this->data_, pos, left)); ASSERT_THAT(this->field_.toString(), "[12]"); @@ -46,6 +56,10 @@ TYPED_TEST(IntTest, ReadWithLeftovers) { this->data_.template writeBEInt(11); uint64_t pos = 0; uint64_t left = this->data_.length(); + ASSERT_THAT(Message::ValidationOK, this->field_.validate(this->data_, 0, pos, left)); + + pos = 0; + left = this->data_.length(); ASSERT_TRUE(this->field_.read(this->data_, pos, left)); ASSERT_THAT(this->field_.toString(), "[12]"); // pos should be moved forward by the number of bytes read. @@ -59,8 +73,13 @@ TYPED_TEST(IntTest, ReadAtOffset) { // write 1 byte before the actual value. this->data_.template writeBEInt(11); this->data_.template writeBEInt().get())>(12); + uint64_t pos = 1; uint64_t left = this->data_.length() - 1; + ASSERT_THAT(Message::ValidationOK, this->field_.validate(this->data_, 1, pos, left)); + + pos = 1; + left = this->data_.length() - 1; ASSERT_TRUE(this->field_.read(this->data_, pos, left)); ASSERT_THAT(this->field_.toString(), "[12]"); // pos should be moved forward by the number of bytes read. @@ -73,8 +92,9 @@ TYPED_TEST(IntTest, NotEnoughData) { this->data_.template writeBEInt().get())>(12); // Start from offset 1. There is not enough data in the buffer for the required type. uint64_t pos = 1; - uint64_t left = this->data_.length() - pos; - ASSERT_FALSE(this->field_.read(this->data_, pos, left)); + uint64_t left = this->data_.length(); + + ASSERT_THAT(this->field_.validate(this->data_, 0, pos, left), Message::ValidationNeedMoreData); } // Byte1 should format content as char. @@ -86,6 +106,12 @@ TEST(Byte1, Formatting) { uint64_t pos = 0; uint64_t left = 1; + ASSERT_THAT(Message::ValidationOK, field.validate(data, 0, pos, left)); + ASSERT_THAT(pos, 1); + ASSERT_THAT(left, 0); + + pos = 0; + left = 1; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 1); ASSERT_THAT(left, 0); @@ -99,9 +125,21 @@ TEST(StringType, SingleString) { Buffer::OwnedImpl data; data.add("test"); - data.writeBEInt(0); + // Passed length 3 is too short. uint64_t pos = 0; - uint64_t left = 5; + uint64_t left = 3; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + // Correct length, but terminating zero is missing. + left = 5; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); + // Add terminating zero. + data.writeBEInt(0); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 5); + ASSERT_THAT(left, 0); + + pos = 0; + left = 5; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 5); ASSERT_THAT(left, 0); @@ -110,42 +148,6 @@ TEST(StringType, SingleString) { ASSERT_THAT(out, "[test]"); } -TEST(StringType, MultipleStrings) { - String field; - - // Add 3 strings. - Buffer::OwnedImpl data; - data.add("test1"); - data.writeBEInt(0); - data.add("test2"); - data.writeBEInt(0); - data.add("test3"); - data.writeBEInt(0); - uint64_t pos = 0; - uint64_t left = 3 * 6; - - // Read the first string. - ASSERT_TRUE(field.read(data, pos, left)); - ASSERT_THAT(pos, 1 * 6); - ASSERT_THAT(left, 2 * 6); - auto out = field.toString(); - ASSERT_THAT(out, "[test1]"); - - // Read the second string. - ASSERT_TRUE(field.read(data, pos, left)); - ASSERT_THAT(pos, 2 * 6); - ASSERT_THAT(left, 1 * 6); - out = field.toString(); - ASSERT_THAT(out, "[test2]"); - - // Read the third string. - ASSERT_TRUE(field.read(data, pos, left)); - ASSERT_THAT(pos, 3 * 6); - ASSERT_THAT(left, 0); - out = field.toString(); - ASSERT_THAT(out, "[test3]"); -} - TEST(StringType, NoTerminatingByte) { String field; @@ -153,7 +155,9 @@ TEST(StringType, NoTerminatingByte) { data.add("test"); uint64_t pos = 0; uint64_t left = 4; - ASSERT_FALSE(field.read(data, pos, left)); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + left = 5; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); } // ByteN type is always placed at the end of Postgres message. @@ -169,10 +173,31 @@ TEST(ByteN, BasicTest) { data.writeBEInt(i); } uint64_t pos = 0; - uint64_t left = 10; + uint64_t left; + + // Since ByteN structure does not contain length field, any + // value less than number of bytes in the buffer should + // pass validation. + pos = 0; + left = 0; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 0); + ASSERT_THAT(left, 0); + pos = 0; + left = 1; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 1); + ASSERT_THAT(left, 0); + pos = 0; + left = 4; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 4); + ASSERT_THAT(left, 0); + + pos = 0; + left = 10; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 10); - // One byte should be left in the buffer. ASSERT_THAT(left, 0); auto out = field.toString(); @@ -189,7 +214,7 @@ TEST(ByteN, NotEnoughData) { } uint64_t pos = 0; uint64_t left = 11; - ASSERT_FALSE(field.read(data, pos, left)); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); } TEST(ByteN, Empty) { @@ -199,6 +224,7 @@ TEST(ByteN, Empty) { // Write nothing to data buffer. uint64_t pos = 0; uint64_t left = 0; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); ASSERT_TRUE(field.read(data, pos, left)); auto out = field.toString(); @@ -208,12 +234,49 @@ TEST(ByteN, Empty) { // VarByteN type. It contains 4 bytes length field with value which follows. TEST(VarByteN, BasicTest) { VarByteN field; - Buffer::OwnedImpl data; + + uint64_t pos = 0; + uint64_t left = 0; + // Simulate that message ended and VarByteN's length fields sticks past the + // message boundary. + data.writeBEInt(5); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + // Write VarByteN with length equal to zero. No value follows. - data.writeBEInt(0); + // Set structure length to be -1 (means no payload). + left = 4; + data.drain(data.length()); + data.writeBEInt(-1); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + // The same for structure length 0. + pos = 0; + left = 4; + data.drain(data.length()); + data.writeBEInt(0); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + + // Simulate that VarByteN would extend past message boundary. + data.drain(data.length()); + data.writeBEInt(30); + pos = 0; + left = 4; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + + // Simulate that VarByteN length is 6, there are 6 bytes left to the + // message boundary, but buffer contains only 4 bytes. + data.drain(data.length()); + data.writeBEInt(6); + data.writeBEInt(16); + pos = 0; + left = 6; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); + + data.drain(data.length()); + // Write first value. + data.writeBEInt(0); - // Write value with 5 bytes. + // Write 2nd value with 5 bytes. data.writeBEInt(5); for (auto i = 0; i < 5; i++) { data.writeBEInt(10 + i); @@ -222,11 +285,15 @@ TEST(VarByteN, BasicTest) { // Write special case value with length -1. No value follows. data.writeBEInt(-1); - uint64_t pos = 0; - uint64_t left = 4 + 4 + 5 + 4; + pos = 0; + left = 4 + 4 + 5 + 4; uint64_t expected_left = left; - + uint64_t orig_pos = pos; + uint64_t orig_left = left; // Read the first value. + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + pos = orig_pos; + left = orig_left; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 4); expected_left -= 4; @@ -235,6 +302,11 @@ TEST(VarByteN, BasicTest) { ASSERT_TRUE(out.find("0 bytes") != std::string::npos); // Read the second value. + orig_pos = pos; + orig_left = left; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + pos = orig_pos; + left = orig_left; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 4 + 4 + 5); expected_left -= (4 + 5); @@ -244,6 +316,11 @@ TEST(VarByteN, BasicTest) { ASSERT_TRUE(out.find("10 11 12 13 14") != std::string::npos); // Read the third value. + orig_pos = pos; + orig_left = left; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + pos = orig_pos; + left = orig_left; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 4 + 4 + 5 + 4); expected_left -= 4; @@ -252,47 +329,31 @@ TEST(VarByteN, BasicTest) { ASSERT_TRUE(out.find("-1 bytes") != std::string::npos); } -TEST(VarByteN, NotEnoughLengthData) { - VarByteN field; - - Buffer::OwnedImpl data; - // Write 3 bytes. Minimum for this type is 4 bytes of length. - data.writeBEInt(0); - data.writeBEInt(1); - data.writeBEInt(2); - - uint64_t pos = 0; - uint64_t left = 3; - ASSERT_FALSE(field.read(data, pos, left)); -} - -TEST(VarByteN, NotEnoughValueData) { - VarByteN field; - - Buffer::OwnedImpl data; - // Write length of the value to be 5 bytes, but supply only 4 bytes. - data.writeBEInt(5); - data.writeBEInt(0); - data.writeBEInt(1); - data.writeBEInt(2); - data.writeBEInt(3); - - uint64_t pos = 0; - uint64_t left = 5 + 4; - ASSERT_FALSE(field.read(data, pos, left)); -} - // Array composite type tests. TEST(Array, SingleInt) { Array field; Buffer::OwnedImpl data; - // Write the number of elements in the array. - data.writeBEInt(1); + // Simulate that message ends before the array. + uint64_t pos = 0; + uint64_t left = 1; + data.writeBEInt(1); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + + // Write the value of the element into the array. + data.drain(data.length()); + data.writeBEInt(1); data.writeBEInt(123); + // Simulate that message length end before end of array. + left = 5; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); - uint64_t pos = 0; - uint64_t left = 2 + 4; + left = 6; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 6); + ASSERT_THAT(left, 0); + pos = 0; + left = 6; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 6); ASSERT_THAT(left, 0); @@ -306,14 +367,29 @@ TEST(Array, MultipleInts) { Array field; Buffer::OwnedImpl data; - // Write 3 elements into array. + // Write 3 as size of array, but add only 2 elements into array. data.writeBEInt(3); data.writeBEInt(211); data.writeBEInt(212); - data.writeBEInt(213); uint64_t pos = 0; uint64_t left = 2 + 3 * 1; + + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); + + // Add the third element. + data.writeBEInt(213); + + // Simulate that message ends before end of the array. + left = 2 + 3 * 1 - 1; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + + left = 2 + 3 * 1; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 5); + ASSERT_THAT(left, 0); + pos = 0; + left = 2 + 3 * 1; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 5); ASSERT_THAT(left, 0); @@ -334,6 +410,11 @@ TEST(Array, Empty) { uint64_t pos = 0; uint64_t left = 2; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 2); + ASSERT_THAT(left, 0); + pos = 0; + left = 2; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 2); ASSERT_THAT(left, 0); @@ -352,7 +433,7 @@ TEST(Array, NotEnoughDataForLength) { uint64_t pos = 0; uint64_t left = 1; - ASSERT_FALSE(field.read(data, pos, left)); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); } // Test situation when there is not enough data in the buffer to read one of the elements @@ -370,7 +451,7 @@ TEST(Array, NotEnoughDataForValues) { uint64_t pos = 0; uint64_t left = 2 + 4 + 2; - ASSERT_FALSE(field.read(data, pos, left)); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); } // Repeated composite type tests. @@ -382,15 +463,37 @@ TEST(Repeated, BasicTestWithStrings) { // It will be ignored. data.writeBEInt(101); data.writeBEInt(102); - // Now write 3 strings. Each terminated by zero byte. + uint64_t pos = 5; + uint64_t left = 5; + // Write the first string without terminating zero. data.add("test1"); - data.writeBEInt(0); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + left = 6; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); + // Add terminating zero. + data.writeBEInt(0); + left = 5; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + left = 7; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); + left = 6; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + // Add two additional strings data.add("test2"); data.writeBEInt(0); data.add("test3"); data.writeBEInt(0); - uint64_t pos = 5; - uint64_t left = 3 * 6; + pos = 5; + left = 3 * 6 - 1; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); + left = 3 * 6 + 1; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationNeedMoreData); + left = 3 * 6; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 5 + 3 * 6); + ASSERT_THAT(left, 0); + pos = 5; + left = 3 * 6; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 5 + 3 * 6); ASSERT_THAT(left, 0); @@ -401,46 +504,6 @@ TEST(Repeated, BasicTestWithStrings) { ASSERT_TRUE(out.find("test3") != std::string::npos); } -// Test verifies that read fails when there is less -// bytes in the buffer than bytes needed to read to the end of the message. -TEST(Repeated, NotEnoughData) { - Repeated field; - - Buffer::OwnedImpl data; - // Write some data to simulate message header. - // It will be ignored. - data.writeBEInt(101); - data.writeBEInt(102); - data.add("test"); - - // "test" with terminating zero is 5 bytes. - // Set "left" to indicate that 6 bytes are needed. - uint64_t pos = 5; - uint64_t left = 5 + 6; - ASSERT_FALSE(field.read(data, pos, left)); -} - -// Test verifies that entire read fails when one of -// subordinate reads fails. -TEST(Repeated, NotEnoughDataForSecondString) { - Repeated field; - - Buffer::OwnedImpl data; - // Write some data to simulate message header. - // It will be ignored. - data.writeBEInt(101); - data.writeBEInt(102); - // Now write 3 strings. Each terminated by zero byte. - data.add("test1"); - data.writeBEInt(0); - data.add("test2"); - // Do not write terminating zero. - // Read should fail here. - uint64_t pos = 5; - uint64_t left = 6 + 5; - ASSERT_FALSE(field.read(data, pos, left)); -} - // Sequence composite type tests. TEST(Sequence, Int32SingleValue) { Sequence field; @@ -450,6 +513,11 @@ TEST(Sequence, Int32SingleValue) { uint64_t pos = 0; uint64_t left = 4; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 4); + ASSERT_THAT(left, 0); + pos = 0; + left = 4; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 4); ASSERT_THAT(left, 0); @@ -466,6 +534,11 @@ TEST(Sequence, Int16SingleValue) { uint64_t pos = 0; uint64_t left = 2; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 2); + ASSERT_THAT(left, 0); + pos = 0; + left = 2; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 2); ASSERT_THAT(left, 0); @@ -484,6 +557,11 @@ TEST(Sequence, BasicMultipleValues1) { uint64_t pos = 0; uint64_t left = 4 + 5; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, 4 + 5); + ASSERT_THAT(left, 0); + pos = 0; + left = 4 + 5; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, 4 + 5); ASSERT_THAT(left, 0); @@ -503,6 +581,11 @@ TEST(Sequence, BasicMultipleValues2) { uint64_t pos = 0; uint64_t left = 4 + 2; uint64_t expected_pos = left; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, expected_pos); + ASSERT_THAT(left, 0); + pos = 0; + left = 4 + 2; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, expected_pos); ASSERT_THAT(left, 0); @@ -524,6 +607,11 @@ TEST(Sequence, BasicMultipleValues3) { uint64_t pos = 0; uint64_t left = 4 + 2 + 4 + 2; uint64_t expected_pos = left; + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationOK); + ASSERT_THAT(pos, expected_pos); + ASSERT_THAT(left, 0); + pos = 0; + left = 4 + 2 + 4 + 2; ASSERT_TRUE(field.read(data, pos, left)); ASSERT_THAT(pos, expected_pos); ASSERT_THAT(left, 0); @@ -547,7 +635,7 @@ TEST(Sequence, NotEnoughData) { uint64_t pos = 0; uint64_t left = 4 + 4; - ASSERT_FALSE(field.read(data, pos, left)); + ASSERT_THAT(field.validate(data, 0, pos, left), Message::ValidationFailed); } // Tests for Message interface and helper function createMsgBodyReader. @@ -555,7 +643,16 @@ TEST(PostgresMessage, SingleFieldInt32) { std::unique_ptr msg = createMsgBodyReader(); Buffer::OwnedImpl data; + // Validation of empty message should complain that there + // is not enough data in the buffer. + ASSERT_THAT(msg->validate(data, 0, 4), Message::ValidationNeedMoreData); + data.writeBEInt(12); + + // Simulate that message is longer than In32. + ASSERT_THAT(msg->validate(data, 0, 5), Message::ValidationFailed); + + ASSERT_THAT(msg->validate(data, 0, 4), Message::ValidationOK); ASSERT_TRUE(msg->read(data, 4)); auto out = msg->toString(); ASSERT_THAT(out, "[12]"); @@ -565,7 +662,13 @@ TEST(PostgresMessage, SingleFieldInt16) { std::unique_ptr msg = createMsgBodyReader(); Buffer::OwnedImpl data; + + // Validation of empty message should complain that there + // is not enough data in the buffer. + ASSERT_THAT(msg->validate(data, 0, 2), Message::ValidationNeedMoreData); + data.writeBEInt(12); + ASSERT_THAT(msg->validate(data, 0, 2), Message::ValidationOK); ASSERT_TRUE(msg->read(data, 2)); auto out = msg->toString(); ASSERT_THAT(out, "[12]"); @@ -575,12 +678,18 @@ TEST(PostgresMessage, SingleByteN) { std::unique_ptr msg = createMsgBodyReader(); Buffer::OwnedImpl data; + // Validation of empty message should complain that there + // is not enough data in the buffer. + ASSERT_THAT(msg->validate(data, 0, 4), Message::ValidationNeedMoreData); + data.writeBEInt(0); data.writeBEInt(1); data.writeBEInt(2); data.writeBEInt(3); data.writeBEInt(4); - ASSERT_TRUE(msg->read(data, 5 * 1)); + const uint64_t length = 5 * 1; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("0") != std::string::npos); // NOLINT ASSERT_TRUE(out.find("1") != std::string::npos); // NOLINT @@ -593,9 +702,16 @@ TEST(PostgresMessage, MultipleValues1) { std::unique_ptr msg = createMsgBodyReader(); Buffer::OwnedImpl data; + + // Validation of empty message should complain that there + // is not enough data in the buffer. + ASSERT_THAT(msg->validate(data, 0, 4), Message::ValidationNeedMoreData); + data.writeBEInt(12); data.writeBEInt(13); - ASSERT_TRUE(msg->read(data, 4 + 2)); + const uint64_t length = 4 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("12") != std::string::npos); ASSERT_TRUE(out.find("13") != std::string::npos); @@ -608,7 +724,9 @@ TEST(PostgresMessage, MultipleValues2) { data.writeBEInt(13); data.writeBEInt(14); data.writeBEInt(15); - ASSERT_TRUE(msg->read(data, 2 + 4 + 2)); + const uint64_t length = 2 + 4 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("13") != std::string::npos); ASSERT_TRUE(out.find("14") != std::string::npos); @@ -623,7 +741,9 @@ TEST(PostgresMessage, MultipleValues3) { data.writeBEInt(13); data.writeBEInt(14); data.writeBEInt(15); - ASSERT_TRUE(msg->read(data, 4 + 2 + 4 + 2)); + const uint64_t length = 4 + 2 + 4 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("12") != std::string::npos); ASSERT_TRUE(out.find("13") != std::string::npos); @@ -640,7 +760,9 @@ TEST(PostgresMessage, MultipleValues4) { data.writeBEInt(15); data.writeBEInt(16); data.writeBEInt(17); - ASSERT_TRUE(msg->read(data, 2 + 4 + 2 + 4 + 2)); + const uint64_t length = 2 + 4 + 2 + 4 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("13") != std::string::npos); ASSERT_TRUE(out.find("14") != std::string::npos); @@ -659,7 +781,9 @@ TEST(PostgresMessage, MultipleValues5) { data.writeBEInt(15); data.writeBEInt(16); data.writeBEInt(17); - ASSERT_TRUE(msg->read(data, 4 + 2 + 4 + 2 + 4 + 2)); + const uint64_t length = 4 + 2 + 4 + 2 + 4 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("12") != std::string::npos); ASSERT_TRUE(out.find("13") != std::string::npos); @@ -682,7 +806,9 @@ TEST(PostgresMessage, MultipleValues6) { data.writeBEInt(15); data.writeBEInt(16); data.writeBEInt(17); - ASSERT_TRUE(msg->read(data, 5 + 4 + 2 + 4 + 2 + 4 + 2)); + const uint64_t length = 5 + 4 + 2 + 4 + 2 + 4 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("test") != std::string::npos); ASSERT_TRUE(out.find("12") != std::string::npos); @@ -705,7 +831,9 @@ TEST(PostgresMessage, MultipleValues7) { data.writeBEInt(13); data.writeBEInt(14); data.writeBEInt(15); - ASSERT_TRUE(msg->read(data, 5 + 2 + 3 * 4)); + const uint64_t length = 5 + 2 + 3 * 4; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("test") != std::string::npos); ASSERT_TRUE(out.find("13") != std::string::npos); @@ -722,7 +850,9 @@ TEST(PostgresMessage, ArraySet1) { data.writeBEInt(13); data.writeBEInt(14); data.writeBEInt(15); - ASSERT_TRUE(msg->read(data, 2 + 3 * 2)); + const uint64_t length = 2 + 3 * 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("13") != std::string::npos); ASSERT_TRUE(out.find("14") != std::string::npos); @@ -745,8 +875,9 @@ TEST(PostgresMessage, ArraySet2) { // 16-bits value. data.writeBEInt(115); - - ASSERT_TRUE(msg->read(data, 2 + 4 + 5 + 2)); + const uint64_t length = 2 + 4 + 5 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("114") != std::string::npos); ASSERT_TRUE(out.find("115") != std::string::npos); @@ -774,8 +905,9 @@ TEST(PostgresMessage, ArraySet3) { // 16-bits value. data.writeBEInt(115); - - ASSERT_TRUE(msg->read(data, 2 + 3 * 2 + 2 + 4 + 5 + 2)); + const uint64_t length = 2 + 3 * 2 + 2 + 4 + 5 + 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("13") != std::string::npos); ASSERT_TRUE(out.find("115") != std::string::npos); @@ -799,8 +931,9 @@ TEST(PostgresMessage, ArraySet4) { data.writeBEInt(2); data.writeBEInt(113); data.writeBEInt(114); - - ASSERT_TRUE(msg->read(data, 2 + 4 + 5 + 2 + 2 * 2)); + const uint64_t length = 2 + 4 + 5 + 2 + 2 * 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("111") != std::string::npos); ASSERT_TRUE(out.find("114") != std::string::npos); @@ -830,8 +963,9 @@ TEST(PostgresMessage, ArraySet5) { data.writeBEInt(2); data.writeBEInt(113); data.writeBEInt(114); - - ASSERT_TRUE(msg->read(data, 2 + 3 * 2 + 2 + 4 + 5 + 2 + 2 * 2)); + const uint64_t length = 2 + 3 * 2 + 2 + 4 + 5 + 2 + 2 * 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("13") != std::string::npos); ASSERT_TRUE(out.find("114") != std::string::npos); @@ -867,7 +1001,9 @@ TEST(PostgresMessage, ArraySet6) { data.writeBEInt(113); data.writeBEInt(114); - ASSERT_TRUE(msg->read(data, 5 + 2 + 3 * 2 + 2 + 4 + 5 + 2 + 2 * 2)); + const uint64_t length = 5 + 2 + 3 * 2 + 2 + 4 + 5 + 2 + 2 * 2; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("test") != std::string::npos); ASSERT_TRUE(out.find("13") != std::string::npos); @@ -886,7 +1022,9 @@ TEST(PostgresMessage, Repeated1) { data.add("test3"); data.writeBEInt(0); - ASSERT_TRUE(msg->read(data, 3 * 6)); + const uint64_t length = 3 * 6; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("test1") != std::string::npos); ASSERT_TRUE(out.find("test2") != std::string::npos); @@ -906,7 +1044,9 @@ TEST(PostgresMessage, Repeated2) { data.add("test3"); data.writeBEInt(0); - ASSERT_TRUE(msg->read(data, 4 + 3 * 6)); + const uint64_t length = 4 + 3 * 6; + ASSERT_THAT(msg->validate(data, 0, length), Message::ValidationOK); + ASSERT_TRUE(msg->read(data, length)); auto out = msg->toString(); ASSERT_TRUE(out.find("115") != std::string::npos); ASSERT_TRUE(out.find("test1") != std::string::npos); @@ -922,7 +1062,27 @@ TEST(PostgresMessage, NotEnoughData) { data.writeBEInt(1); data.writeBEInt(2); - ASSERT_FALSE(msg->read(data, 3)); + ASSERT_THAT(msg->validate(data, 0, 4), Message::ValidationNeedMoreData); + ASSERT_THAT(msg->validate(data, 0, 2), Message::ValidationFailed); +} + +// Test checks validating a properly formatted message +// which starts at some offset in data buffer. +TEST(PostgresMessage, ValidateFromOffset) { + std::unique_ptr msg = createMsgBodyReader(); + Buffer::OwnedImpl data; + + // Write some data which should be skipped by validator. + data.add("skip"); + data.writeBEInt(0); + + // Write valid data according to message syntax. + data.writeBEInt(110); + data.add("test123"); + data.writeBEInt(0); + + // Skip first 5 bytes in the buffer. + ASSERT_THAT(msg->validate(data, 5, 4 + 8), Message::ValidationOK); } } // namespace PostgresProxy diff --git a/test/extensions/filters/network/postgres_proxy/postgres_test_utils.cc b/test/extensions/filters/network/postgres_proxy/postgres_test_utils.cc index 51819b1ad72e2..56450f5c5fe98 100644 --- a/test/extensions/filters/network/postgres_proxy/postgres_test_utils.cc +++ b/test/extensions/filters/network/postgres_proxy/postgres_test_utils.cc @@ -10,8 +10,11 @@ void createPostgresMsg(Buffer::Instance& data, std::string type, std::string pay data.drain(data.length()); ASSERT(1 == type.length()); data.add(type); - data.writeBEInt(4 + payload.length()); - data.add(payload); + data.writeBEInt(4 + (payload.empty() ? 0 : (payload.length() + 1))); + if (!payload.empty()) { + data.add(payload); + data.writeBEInt(0); + } } } // namespace PostgresProxy