diff --git a/include/envoy/stream_info/stream_info.h b/include/envoy/stream_info/stream_info.h index dc4a679739767..99e92b3a72412 100644 --- a/include/envoy/stream_info/stream_info.h +++ b/include/envoy/stream_info/stream_info.h @@ -236,6 +236,11 @@ class StreamInfo { */ virtual void setResponseFlag(ResponseFlag response_flag) PURE; + /** + * @param code the HTTP response code to set for this request. + */ + virtual void setResponseCode(uint32_t code) PURE; + /** * @param rc_details the response code details string to set for this request. * See ResponseCodeDetailValues above for well-known constants. diff --git a/source/common/local_reply/local_reply.cc b/source/common/local_reply/local_reply.cc index a960ffa4e952e..74cd52fabb650 100644 --- a/source/common/local_reply/local_reply.cc +++ b/source/common/local_reply/local_reply.cc @@ -74,7 +74,7 @@ class ResponseMapper { bool matchAndRewrite(const Http::RequestHeaderMap& request_headers, Http::ResponseHeaderMap& response_headers, const Http::ResponseTrailerMap& response_trailers, - StreamInfo::StreamInfoImpl& stream_info, Http::Code& code, std::string& body, + StreamInfo::StreamInfo& stream_info, Http::Code& code, std::string& body, BodyFormatter*& final_formatter) const { // If not matched, just bail out. if (!filter_->evaluate(stream_info, request_headers, response_headers, response_trailers)) { @@ -90,7 +90,7 @@ class ResponseMapper { if (status_code_.has_value() && code != status_code_.value()) { code = status_code_.value(); response_headers.setStatus(std::to_string(enumToInt(code))); - stream_info.response_code_ = static_cast(code); + stream_info.setResponseCode(static_cast(code)); } if (body_formatter_) { @@ -126,14 +126,14 @@ class LocalReplyImpl : public LocalReply { } void rewrite(const Http::RequestHeaderMap* request_headers, - Http::ResponseHeaderMap& response_headers, StreamInfo::StreamInfoImpl& stream_info, + Http::ResponseHeaderMap& response_headers, StreamInfo::StreamInfo& stream_info, Http::Code& code, std::string& body, absl::string_view& content_type) const override { // Set response code to stream_info and response_headers due to: // 1) StatusCode filter is using response_code from stream_info, // 2) %RESP(:status)% is from Status() in response_headers. response_headers.setStatus(std::to_string(enumToInt(code))); - stream_info.response_code_ = static_cast(code); + stream_info.setResponseCode(static_cast(code)); if (request_headers == nullptr) { request_headers = Http::StaticEmptyHeaders::get().request_headers.get(); diff --git a/source/common/local_reply/local_reply.h b/source/common/local_reply/local_reply.h index cafcaf33d3079..5db93caa07fda 100644 --- a/source/common/local_reply/local_reply.h +++ b/source/common/local_reply/local_reply.h @@ -24,7 +24,7 @@ class LocalReply { */ virtual void rewrite(const Http::RequestHeaderMap* request_headers, Http::ResponseHeaderMap& response_headers, - StreamInfo::StreamInfoImpl& stream_info, Http::Code& code, std::string& body, + StreamInfo::StreamInfo& stream_info, Http::Code& code, std::string& body, absl::string_view& content_type) const PURE; }; diff --git a/source/common/stream_info/stream_info_impl.h b/source/common/stream_info/stream_info_impl.h index 1c0614d93cbf1..4f37abe3fbf6b 100644 --- a/source/common/stream_info/stream_info_impl.h +++ b/source/common/stream_info/stream_info_impl.h @@ -131,6 +131,8 @@ struct StreamInfoImpl : public StreamInfo { return response_code_details_; } + void setResponseCode(uint32_t code) override { response_code_ = code; } + void setResponseCodeDetails(absl::string_view rc_details) override { response_code_details_.emplace(absl::StrReplaceAll(rc_details, emptySpaceReplacement())); } diff --git a/test/common/stream_info/test_util.h b/test/common/stream_info/test_util.h index 560b485e18c07..685607f876c32 100644 --- a/test/common/stream_info/test_util.h +++ b/test/common/stream_info/test_util.h @@ -38,6 +38,7 @@ class TestStreamInfo : public StreamInfo::StreamInfo { const absl::optional& responseCodeDetails() const override { return response_code_details_; } + void setResponseCode(uint32_t code) override { response_code_ = code; } void setResponseCodeDetails(absl::string_view rc_details) override { response_code_details_.emplace(rc_details); } diff --git a/test/mocks/local_reply/mocks.h b/test/mocks/local_reply/mocks.h index 3d0a7ddeab882..913f815d50691 100644 --- a/test/mocks/local_reply/mocks.h +++ b/test/mocks/local_reply/mocks.h @@ -11,9 +11,9 @@ class MockLocalReply : public LocalReply { MOCK_METHOD(void, rewrite, (const Http::RequestHeaderMap* request_headers, - Http::ResponseHeaderMap& response_headers, StreamInfo::StreamInfoImpl& stream_info, + Http::ResponseHeaderMap& response_headers, StreamInfo::StreamInfo& stream_info, Http::Code& code, std::string& body, absl::string_view& content_type), (const)); }; } // namespace LocalReply -} // namespace Envoy \ No newline at end of file +} // namespace Envoy diff --git a/test/mocks/stream_info/mocks.cc b/test/mocks/stream_info/mocks.cc index 8373bdeb36032..066734a71507b 100644 --- a/test/mocks/stream_info/mocks.cc +++ b/test/mocks/stream_info/mocks.cc @@ -24,6 +24,9 @@ MockStreamInfo::MockStreamInfo() ON_CALL(*this, setResponseFlag(_)).WillByDefault(Invoke([this](ResponseFlag response_flag) { response_flags_ |= response_flag; })); + ON_CALL(*this, setResponseCode(_)).WillByDefault(Invoke([this](uint32_t code) { + response_code_ = code; + })); ON_CALL(*this, setResponseCodeDetails(_)).WillByDefault(Invoke([this](absl::string_view details) { response_code_details_ = std::string(details); })); diff --git a/test/mocks/stream_info/mocks.h b/test/mocks/stream_info/mocks.h index b02b849c2310b..9ed11c0afa966 100644 --- a/test/mocks/stream_info/mocks.h +++ b/test/mocks/stream_info/mocks.h @@ -21,6 +21,7 @@ class MockStreamInfo : public StreamInfo { // StreamInfo::StreamInfo MOCK_METHOD(void, setResponseFlag, (ResponseFlag response_flag)); + MOCK_METHOD(void, setResponseCode, (uint32_t)); MOCK_METHOD(void, setResponseCodeDetails, (absl::string_view)); MOCK_METHOD(void, setConnectionTerminationDetails, (absl::string_view)); MOCK_METHOD(bool, intersectResponseFlags, (uint64_t), (const));