diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index 748119568bca7..fa75e27fa82e8 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -657,10 +657,16 @@ Http::HeaderMap* Context::getMap(WasmHeaderMapType type) { case WasmHeaderMapType::RequestHeaders: return request_headers_; case WasmHeaderMapType::RequestTrailers: + if (request_trailers_ == nullptr && request_body_buffer_ && end_of_stream_) { + request_trailers_ = &decoder_callbacks_->addDecodedTrailers(); + } return request_trailers_; case WasmHeaderMapType::ResponseHeaders: return response_headers_; case WasmHeaderMapType::ResponseTrailers: + if (response_trailers_ == nullptr && response_body_buffer_ && end_of_stream_) { + response_trailers_ = &encoder_callbacks_->addEncodedTrailers(); + } return response_trailers_; default: return nullptr; diff --git a/test/extensions/filters/http/wasm/test_data/headers_rust.rs b/test/extensions/filters/http/wasm/test_data/headers_rust.rs index c2aee58f7b0ba..1fbc30680e5eb 100644 --- a/test/extensions/filters/http/wasm/test_data/headers_rust.rs +++ b/test/extensions/filters/http/wasm/test_data/headers_rust.rs @@ -1,4 +1,4 @@ -use log::{trace, debug, error, info, warn}; +use log::{debug, error, info, trace, warn}; use proxy_wasm::traits::{Context, HttpContext}; use proxy_wasm::types::*; @@ -49,10 +49,13 @@ impl HttpContext for TestStream { action } - fn on_http_request_body(&mut self, body_size: usize, _: bool) -> Action { + fn on_http_request_body(&mut self, body_size: usize, end_of_stream: bool) -> Action { if let Some(body) = self.get_http_request_body(0, body_size) { error!("onBody {}", String::from_utf8(body).unwrap()); } + if end_of_stream { + self.add_http_request_trailer("newtrailer", "request"); + } Action::Continue } @@ -61,6 +64,13 @@ impl HttpContext for TestStream { Action::Continue } + fn on_http_response_body(&mut self, _: usize, end_of_stream: bool) -> Action { + if end_of_stream { + self.add_http_response_trailer("newtrailer", "response"); + } + Action::Continue + } + fn on_http_response_trailers(&mut self, _: usize) -> Action { Action::Pause } diff --git a/test/extensions/filters/http/wasm/test_data/test_cpp.cc b/test/extensions/filters/http/wasm/test_data/test_cpp.cc index 705e468dce210..69f19e2c8b53d 100644 --- a/test/extensions/filters/http/wasm/test_data/test_cpp.cc +++ b/test/extensions/filters/http/wasm/test_data/test_cpp.cc @@ -38,6 +38,7 @@ class TestContext : public Context { FilterHeadersStatus onResponseHeaders(uint32_t, bool) override; FilterTrailersStatus onResponseTrailers(uint32_t) override; FilterDataStatus onRequestBody(size_t body_buffer_length, bool end_of_stream) override; + FilterDataStatus onResponseBody(size_t body_buffer_length, bool end_of_stream) override; void onLog() override; void onDone() override; @@ -306,11 +307,14 @@ FilterTrailersStatus TestContext::onResponseTrailers(uint32_t) { return FilterTrailersStatus::StopIteration; } -FilterDataStatus TestContext::onRequestBody(size_t body_buffer_length, bool) { +FilterDataStatus TestContext::onRequestBody(size_t body_buffer_length, bool end_of_stream) { auto test = root()->test_; if (test == "headers") { auto body = getBufferBytes(WasmBufferType::HttpRequestBody, 0, body_buffer_length); logError(std::string("onBody ") + std::string(body->view())); + if (end_of_stream) { + CHECK_RESULT(addRequestTrailer("newtrailer", "request")); + } } else if (test == "metadata") { std::string value; if (!getValue({"node", "metadata", "wasm_node_get_key"}, &value)) { @@ -335,6 +339,16 @@ FilterDataStatus TestContext::onRequestBody(size_t body_buffer_length, bool) { return FilterDataStatus::Continue; } +FilterDataStatus TestContext::onResponseBody(size_t, bool end_of_stream) { + auto test = root()->test_; + if (test == "headers") { + if (end_of_stream) { + CHECK_RESULT(addResponseTrailer("newtrailer", "response")); + } + } + return FilterDataStatus::Continue; +} + void TestContext::onLog() { auto test = root()->test_; if (test == "headers") { @@ -351,10 +365,10 @@ void TestContext::onLog() { logWarn("response bogus-trailer found"); } } else if (test == "cluster_metadata") { - std::string cluster_metadata; - if (getValue({"cluster_metadata", "filter_metadata", "namespace", "key"}, &cluster_metadata)) { - logWarn("cluster metadata: " + cluster_metadata); - } + std::string cluster_metadata; + if (getValue({"cluster_metadata", "filter_metadata", "namespace", "key"}, &cluster_metadata)) { + logWarn("cluster metadata: " + cluster_metadata); + } } else if (test == "property") { setFilterState("wasm_state", "wasm_value"); auto path = getRequestHeader(":path"); diff --git a/test/extensions/filters/http/wasm/wasm_filter_test.cc b/test/extensions/filters/http/wasm/wasm_filter_test.cc index 2e484228b1fd1..adfff9ce26f4d 100644 --- a/test/extensions/filters/http/wasm/wasm_filter_test.cc +++ b/test/extensions/filters/http/wasm/wasm_filter_test.cc @@ -191,6 +191,43 @@ TEST_P(WasmHttpFilterTest, AllHeadersAndTrailers) { filter().onDestroy(); } +TEST_P(WasmHttpFilterTest, AddTrailers) { + setupTest("", "headers"); + setupFilter(); + EXPECT_CALL(encoder_callbacks_, streamInfo()).WillRepeatedly(ReturnRef(request_stream_info_)); + EXPECT_CALL(filter(), + log_(spdlog::level::debug, Eq(absl::string_view("onRequestHeaders 2 headers")))); + EXPECT_CALL(filter(), log_(spdlog::level::info, Eq(absl::string_view("header path /")))); + EXPECT_CALL(filter(), log_(spdlog::level::err, Eq(absl::string_view("onBody data")))).Times(2); + EXPECT_CALL(filter(), log_(spdlog::level::warn, Eq(absl::string_view("onDone 2")))); + + Http::TestRequestHeaderMapImpl request_headers{{":path", "/"}, {"server", "envoy"}}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, false)); + EXPECT_THAT(request_headers.get_("newheader"), Eq("newheadervalue")); + EXPECT_THAT(request_headers.get_("server"), Eq("envoy-wasm")); + + Buffer::OwnedImpl data("data"); + Http::TestRequestTrailerMapImpl request_trailers{}; + EXPECT_CALL(decoder_callbacks_, addDecodedTrailers()).Times(0); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter().decodeData(data, false)); + EXPECT_CALL(decoder_callbacks_, addDecodedTrailers()).WillOnce(ReturnRef(request_trailers)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter().decodeData(data, true)); + EXPECT_THAT(request_trailers.get_("newtrailer"), Eq("request")); + + Http::TestResponseHeaderMapImpl response_headers{}; + EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().encodeHeaders(response_headers, false)); + EXPECT_THAT(response_headers.get_("test-status"), Eq("OK")); + + Http::TestResponseTrailerMapImpl response_trailers{}; + EXPECT_CALL(encoder_callbacks_, addEncodedTrailers()).Times(0); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter().encodeData(data, false)); + EXPECT_CALL(encoder_callbacks_, addEncodedTrailers()).WillOnce(ReturnRef(response_trailers)); + EXPECT_EQ(Http::FilterDataStatus::Continue, filter().encodeData(data, true)); + EXPECT_THAT(response_trailers.get_("newtrailer"), Eq("response")); + + filter().onDestroy(); +} + TEST_P(WasmHttpFilterTest, AllHeadersAndTrailersNotStarted) { setupTest("", "headers"); setupFilter(); @@ -224,7 +261,10 @@ TEST_P(WasmHttpFilterTest, HeadersOnlyRequestHeadersAndBody) { EXPECT_EQ(Http::FilterHeadersStatus::Continue, filter().decodeHeaders(request_headers, false)); EXPECT_FALSE(filter().endOfStream(proxy_wasm::WasmStreamType::Request)); Buffer::OwnedImpl data("hello"); + Http::TestRequestTrailerMapImpl request_trailers{}; + EXPECT_CALL(decoder_callbacks_, addDecodedTrailers()).WillOnce(ReturnRef(request_trailers)); EXPECT_EQ(Http::FilterDataStatus::Continue, filter().decodeData(data, true)); + EXPECT_THAT(request_trailers.get_("newtrailer"), Eq("request")); filter().onDestroy(); }