diff --git a/source/extensions/common/wasm/context.cc b/source/extensions/common/wasm/context.cc index e6e4f8ae0f057..86221f36447ab 100644 --- a/source/extensions/common/wasm/context.cc +++ b/source/extensions/common/wasm/context.cc @@ -1568,6 +1568,7 @@ WasmResult Context::sendLocalResponse(uint32_t response_code, absl::string_view details = std::string(details)] { decoder_callbacks_->sendLocalReply(static_cast(response_code), body_text, modify_headers, grpc_status, details); + http_local_response_sent_ = true; }); } return WasmResult::Ok; @@ -1579,9 +1580,12 @@ Http::FilterHeadersStatus Context::decodeHeaders(Http::RequestHeaderMap& headers request_headers_ = &headers; end_of_stream_ = end_stream; auto result = convertFilterHeadersStatus(onRequestHeaders(headerSize(&headers), end_stream)); - if (result == Http::FilterHeadersStatus::Continue) { + if (http_local_response_sent_) { + return Http::FilterHeadersStatus::StopIteration; + } else if (result == Http::FilterHeadersStatus::Continue) { request_headers_ = nullptr; } + return result; } @@ -1594,6 +1598,9 @@ Http::FilterDataStatus Context::decodeData(::Envoy::Buffer::Instance& data, bool const auto buffer = getBuffer(WasmBufferType::HttpRequestBody); const auto buffer_size = (buffer == nullptr) ? 0 : buffer->size(); auto result = convertFilterDataStatus(onRequestBody(buffer_size, end_stream)); + if (http_local_response_sent_) { + return Http::FilterDataStatus::StopIterationNoBuffer; + } buffering_request_body_ = false; switch (result) { case Http::FilterDataStatus::Continue: @@ -1615,7 +1622,9 @@ Http::FilterTrailersStatus Context::decodeTrailers(Http::RequestTrailerMap& trai } request_trailers_ = &trailers; auto result = convertFilterTrailersStatus(onRequestTrailers(headerSize(&trailers))); - if (result == Http::FilterTrailersStatus::Continue) { + if (http_local_response_sent_) { + return Http::FilterTrailersStatus::StopIteration; + } else if (result == Http::FilterTrailersStatus::Continue) { request_trailers_ = nullptr; } return result; @@ -1649,7 +1658,9 @@ Http::FilterHeadersStatus Context::encodeHeaders(Http::ResponseHeaderMap& header response_headers_ = &headers; end_of_stream_ = end_stream; auto result = convertFilterHeadersStatus(onResponseHeaders(headerSize(&headers), end_stream)); - if (result == Http::FilterHeadersStatus::Continue) { + if (http_local_response_sent_) { + return Http::FilterHeadersStatus::StopIteration; + } else if (result == Http::FilterHeadersStatus::Continue) { response_headers_ = nullptr; } return result; @@ -1664,6 +1675,10 @@ Http::FilterDataStatus Context::encodeData(::Envoy::Buffer::Instance& data, bool const auto buffer = getBuffer(WasmBufferType::HttpResponseBody); const auto buffer_size = (buffer == nullptr) ? 0 : buffer->size(); auto result = convertFilterDataStatus(onResponseBody(buffer_size, end_stream)); + if (http_local_response_sent_) { + return Http::FilterDataStatus::StopIterationNoBuffer; + } + buffering_response_body_ = false; switch (result) { case Http::FilterDataStatus::Continue: @@ -1685,9 +1700,13 @@ Http::FilterTrailersStatus Context::encodeTrailers(Http::ResponseTrailerMap& tra } response_trailers_ = &trailers; auto result = convertFilterTrailersStatus(onResponseTrailers(headerSize(&trailers))); + if (http_local_response_sent_) { + return Http::FilterTrailersStatus::StopIteration; + } if (result == Http::FilterTrailersStatus::Continue) { response_trailers_ = nullptr; } + return result; } diff --git a/source/extensions/common/wasm/context.h b/source/extensions/common/wasm/context.h index e288c1e506024..6f77b907cbc85 100644 --- a/source/extensions/common/wasm/context.h +++ b/source/extensions/common/wasm/context.h @@ -425,6 +425,7 @@ class Context : public proxy_wasm::ContextBase, // HTTP filter state. bool http_request_started_ = false; // When decodeHeaders() is called the request is "started". + bool http_local_response_sent_ = false; // indicates if the local response is sent Http::RequestHeaderMap* request_headers_{}; Http::ResponseHeaderMap* response_headers_{}; ::Envoy::Buffer::Instance* request_body_buffer_{};