From 65569bf297848d185d3ef0aface27e30530604e1 Mon Sep 17 00:00:00 2001 From: RRobot-lm Date: Thu, 22 Dec 2022 09:54:51 +0800 Subject: [PATCH] Implement Add method for header Signed-off-by: RRobot-lm --- pkg/api/capi.go | 2 +- pkg/http/api.h | 7 ++- pkg/http/capi.go | 10 +++- pkg/http/type.go | 15 +++--- src/envoy/common/dso/api.h | 7 ++- src/envoy/http/golang/cgo.cc | 6 +-- src/envoy/http/golang/golang_filter.cc | 20 ++++++-- src/envoy/http/golang/golang_filter.h | 2 +- test/http/golang/golang_integration_test.cc | 54 +++++++++++++++++++++ test/http/golang/test_data/basic/filter.go | 52 +++++++++++++++++++- 10 files changed, 156 insertions(+), 19 deletions(-) diff --git a/pkg/api/capi.go b/pkg/api/capi.go index 704228692f..b59cf6fabe 100644 --- a/pkg/api/capi.go +++ b/pkg/api/capi.go @@ -26,7 +26,7 @@ type HttpCAPI interface { // experience api, memory unsafe HttpGetHeader(r unsafe.Pointer, key *string, value *string) HttpCopyHeaders(r unsafe.Pointer, num uint64, bytes uint64) map[string][]string - HttpSetHeader(r unsafe.Pointer, key *string, value *string) + HttpSetHeader(r unsafe.Pointer, key *string, value *string, add bool) HttpRemoveHeader(r unsafe.Pointer, key *string) HttpGetBuffer(r unsafe.Pointer, bufferPtr uint64, value *string, length uint64) diff --git a/pkg/http/api.h b/pkg/http/api.h index 6e5fdffe14..4561a0f0dc 100644 --- a/pkg/http/api.h +++ b/pkg/http/api.h @@ -22,13 +22,18 @@ typedef enum { Prepend, } bufferAction; +typedef enum { + HeaderSet, + HeaderAdd, +} headerAction; + int moeHttpContinue(void* r, int status); int moeHttpSendLocalReply(void* r, int response_code, void* body_text, void* headers, long long int grpc_status, void* details); int moeHttpGetHeader(void* r, void* key, void* value); int moeHttpCopyHeaders(void* r, void* strs, void* buf); -int moeHttpSetHeader(void* r, void* key, void* value); +int moeHttpSetHeaderHelper(void* r, void* key, void* value, headerAction action); int moeHttpRemoveHeader(void* r, void* key); int moeHttpGetBuffer(void* r, unsigned long long int buffer, void* value); diff --git a/pkg/http/capi.go b/pkg/http/capi.go index e6104d87a2..326aa2766f 100644 --- a/pkg/http/capi.go +++ b/pkg/http/capi.go @@ -112,8 +112,14 @@ func (c *httpCApiImpl) HttpCopyHeaders(r unsafe.Pointer, num uint64, bytes uint6 return m } -func (c *httpCApiImpl) HttpSetHeader(r unsafe.Pointer, key *string, value *string) { - res := C.moeHttpSetHeader(r, unsafe.Pointer(key), unsafe.Pointer(value)) +func (c *httpCApiImpl) HttpSetHeader(r unsafe.Pointer, key *string, value *string, add bool) { + var act C.headerAction + if add { + act = C.HeaderAdd + } else { + act = C.HeaderSet + } + res := C.moeHttpSetHeaderHelper(r, unsafe.Pointer(key), unsafe.Pointer(value), act) handleCApiStatus(res) } diff --git a/pkg/http/type.go b/pkg/http/type.go index b92978316c..7ec0b49b1d 100644 --- a/pkg/http/type.go +++ b/pkg/http/type.go @@ -75,11 +75,18 @@ func (h *headerMapImpl) Set(key, value string) { if h.headers != nil { h.headers[key] = []string{value} } - cAPI.HttpSetHeader(unsafe.Pointer(h.request.req), &key, &value) + cAPI.HttpSetHeader(unsafe.Pointer(h.request.req), &key, &value, false) } func (h *headerMapImpl) Add(key, value string) { - // TODO: add + if h.headers != nil { + if hdrs, found := h.headers[key]; found { + h.headers[key] = append(hdrs, value) + } else { + h.headers[key] = []string{value} + } + } + cAPI.HttpSetHeader(unsafe.Pointer(h.request.req), &key, &value, true) } func (h *headerMapImpl) Del(key string) { @@ -140,10 +147,6 @@ func (h *responseHeaderMapImpl) GetRaw(key string) string { return value } -func (h *responseHeaderMapImpl) Add(key, value string) { - // TODO: add -} - func (h *responseHeaderMapImpl) Del(key string) { if h.headers != nil { delete(h.headers, key) diff --git a/src/envoy/common/dso/api.h b/src/envoy/common/dso/api.h index 6e5fdffe14..4561a0f0dc 100644 --- a/src/envoy/common/dso/api.h +++ b/src/envoy/common/dso/api.h @@ -22,13 +22,18 @@ typedef enum { Prepend, } bufferAction; +typedef enum { + HeaderSet, + HeaderAdd, +} headerAction; + int moeHttpContinue(void* r, int status); int moeHttpSendLocalReply(void* r, int response_code, void* body_text, void* headers, long long int grpc_status, void* details); int moeHttpGetHeader(void* r, void* key, void* value); int moeHttpCopyHeaders(void* r, void* strs, void* buf); -int moeHttpSetHeader(void* r, void* key, void* value); +int moeHttpSetHeaderHelper(void* r, void* key, void* value, headerAction action); int moeHttpRemoveHeader(void* r, void* key); int moeHttpGetBuffer(void* r, unsigned long long int buffer, void* value); diff --git a/src/envoy/http/golang/cgo.cc b/src/envoy/http/golang/cgo.cc index cb79af6d42..2a35eb859d 100644 --- a/src/envoy/http/golang/cgo.cc +++ b/src/envoy/http/golang/cgo.cc @@ -65,11 +65,11 @@ int moeHttpCopyHeaders(void* r, void* strs, void* buf) { }); } -int moeHttpSetHeader(void* r, void* key, void* value) { - return moeHandlerWrapper(r, [key, value](std::shared_ptr& filter) -> int { +int moeHttpSetHeaderHelper(void* r, void* key, void* value, headerAction act) { + return moeHandlerWrapper(r, [key, value, act](std::shared_ptr& filter) -> int { auto keyStr = copyGoString(key); auto valueStr = copyGoString(value); - return filter->setHeader(keyStr, valueStr); + return filter->setHeader(keyStr, valueStr, act); }); } diff --git a/src/envoy/http/golang/golang_filter.cc b/src/envoy/http/golang/golang_filter.cc index f277365c2b..7317b89949 100644 --- a/src/envoy/http/golang/golang_filter.cc +++ b/src/envoy/http/golang/golang_filter.cc @@ -584,7 +584,8 @@ int Filter::sendLocalReply(Http::Code response_code, absl::string_view body_text if (!state.isProcessingInGo()) { return CAPINotInGo; } - ENVOY_LOG(debug, "sendLocalReply, response code: {}", int(response_code)); + + ENVOY_LOG(debug, "sendLocalReply, response code: {}, body: {}", int(response_code), body_text); auto weak_ptr = weak_from_this(); state.getDispatcher().post( @@ -693,7 +694,7 @@ int Filter::copyHeaders(GoString* goStrs, char* goBuf) { return CAPIOK; } -int Filter::setHeader(absl::string_view key, absl::string_view value) { +int Filter::setHeader(absl::string_view key, absl::string_view value, headerAction act) { std::lock_guard lock(mutex_); if (has_destroyed_) { return CAPIFilterIsDestroy; @@ -705,7 +706,20 @@ int Filter::setHeader(absl::string_view key, absl::string_view value) { if (headers_ == nullptr) { return CAPIInvalidPhase; } - headers_->setCopy(Http::LowerCaseString(key), value); + + switch (act) { + case HeaderAdd: + headers_->addCopy(Http::LowerCaseString(key), value); + break; + + case HeaderSet: + headers_->setCopy(Http::LowerCaseString(key), value); + break; + + default: + ENVOY_LOG(error, "unknown header action {}, ignored", act); + } + return CAPIOK; } diff --git a/src/envoy/http/golang/golang_filter.h b/src/envoy/http/golang/golang_filter.h index b6fa1170a9..3acb9c02ee 100644 --- a/src/envoy/http/golang/golang_filter.h +++ b/src/envoy/http/golang/golang_filter.h @@ -162,7 +162,7 @@ class Filter : public Http::StreamFilter, int getHeader(absl::string_view key, GoString* goValue); int copyHeaders(GoString* goStrs, char* goBuf); - int setHeader(absl::string_view key, absl::string_view value); + int setHeader(absl::string_view key, absl::string_view value, headerAction act); int removeHeader(absl::string_view key); int copyBuffer(Buffer::Instance* buffer, char* data); int setBufferHelper(Buffer::Instance* buffer, absl::string_view& value, bufferAction action); diff --git a/test/http/golang/golang_integration_test.cc b/test/http/golang/golang_integration_test.cc index ab0817eafd..fd3b9d0f26 100644 --- a/test/http/golang/golang_integration_test.cc +++ b/test/http/golang/golang_integration_test.cc @@ -419,6 +419,58 @@ name: envoy.bootstrap.dso cleanup(); } + void testAddHeader() { + initializeSimpleFilter(BASIC); + + codec_client_ = makeHttpConnection(makeClientConnection(lookupPort("http"))); + Http::TestRequestHeaderMapImpl request_headers{ + {":method", "POST"}, {":path", "/test?add_header=1"}, + {":scheme", "http"}, {":authority", "test.com"}, + {"x-test-header-0", "foo"}, + }; + + auto encoder_decoder = codec_client_->startRequest(request_headers); + Http::RequestEncoder& request_encoder = encoder_decoder.first; + auto response = std::move(encoder_decoder.second); + codec_client_->sendData(request_encoder, "", true); + + waitForNextUpstreamRequest(); + + EXPECT_EQ("foo", upstream_request_->headers() + .get(Http::LowerCaseString("x-test-header-0"))[0] + ->value() + .getStringView()); + EXPECT_EQ("bar", upstream_request_->headers() + .get(Http::LowerCaseString("x-test-header-0"))[1] + ->value() + .getStringView()); + EXPECT_EQ("baz", upstream_request_->headers() + .get(Http::LowerCaseString("x-test-header-1"))[0] + ->value() + .getStringView()); + + Http::TestResponseHeaderMapImpl response_headers{ + {":status", "200"}, {"x-test-header-0", "foo"}}; + upstream_request_->encodeHeaders(response_headers, true); + + ASSERT_TRUE(response->waitForEndStream()); + + EXPECT_EQ("foo", response->headers() + .get(Http::LowerCaseString("x-test-header-0"))[0] + ->value() + .getStringView()); + EXPECT_EQ("bar", response->headers() + .get(Http::LowerCaseString("x-test-header-0"))[1] + ->value() + .getStringView()); + EXPECT_EQ("baz", response->headers() + .get(Http::LowerCaseString("x-test-header-1"))[0] + ->value() + .getStringView()); + + cleanup(); + } + void testRouteConfig(std::string domain, std::string path, bool header_0_existing, std::string set_header) { initializeSimpleFilter(ROUTECONFIG); @@ -589,6 +641,8 @@ TEST_P(GolangIntegrationTest, Basic) { testBasic("/test"); } TEST_P(GolangIntegrationTest, Async) { testBasic("/test?async=1"); } +TEST_P(GolangIntegrationTest, AddHeader) { testAddHeader(); } + TEST_P(GolangIntegrationTest, DataBuffer_DecodeHeader) { testBasic("/test?databuffer=decode-header"); } diff --git a/test/http/golang/test_data/basic/filter.go b/test/http/golang/test_data/basic/filter.go index e33987d376..7b43d27275 100644 --- a/test/http/golang/test_data/basic/filter.go +++ b/test/http/golang/test_data/basic/filter.go @@ -28,6 +28,7 @@ type filter struct { localreplay string // send local reply databuffer string // return api.Stop panic string // trigger panic in which phase + add_header bool // add header } func parseQuery(path string) url.Values { @@ -54,13 +55,16 @@ func (f *filter) initRequest(header api.HeaderMap) { if f.query_params.Get("decode_localrepaly") != "" { f.data_sleep = true } + if f.query_params.Get("add_header") != "" { + f.add_header = true + } f.databuffer = f.query_params.Get("databuffer") f.localreplay = f.query_params.Get("localreply") f.panic = f.query_params.Get("panic") } func (f *filter) fail(msg string, a ...any) api.StatusType { - body := fmt.Sprintf(msg, a) + body := fmt.Sprintf(msg, a...) f.callbacks.SendLocalReply(500, body, nil, -1, "") return api.LocalReply } @@ -80,6 +84,26 @@ func (f *filter) decodeHeaders(header api.RequestHeaderMap, endStream bool) api. if strings.Contains(f.localreplay, "decode-header") { return f.sendLocalReply("decode-header") } + if f.add_header { + // Trigger the cache + header.Get("x-test-header-0") + // Add to existed header + header.Add("x-test-header-0", "bar") + // Add to non-existed header + header.Add("x-test-header-1", "baz") + + // check the cache + hdrs := header.Values("x-test-header-0") + if len(hdrs) != 2 || hdrs[0] != "foo" || hdrs[1] != "bar" { + return f.fail("header Values x-test-header-0: unexpected %v", hdrs) + } + + hdrs = header.Values("x-test-header-1") + if len(hdrs) != 1 || hdrs[0] != "baz" { + return f.fail("header Values x-test-header-1: unexpected %v", hdrs) + } + return api.Continue + } origin, found := header.Get("x-test-header-0") if found { @@ -112,6 +136,9 @@ func (f *filter) decodeData(buffer api.BufferInstance, endStream bool) api.Statu if f.sleep || f.data_sleep { time.Sleep(time.Millisecond * 100) // sleep 100 ms } + if f.add_header { + return api.Continue + } if strings.Contains(f.localreplay, "decode-data") { return f.sendLocalReply("decode-data") } @@ -153,6 +180,26 @@ func (f *filter) encodeHeaders(header api.ResponseHeaderMap, endStream bool) api if strings.Contains(f.localreplay, "encode-header") { return f.sendLocalReply("encode-header") } + if f.add_header { + // Trigger the cache + header.Get("x-test-header-0") + // Add to existed header + header.Add("x-test-header-0", "bar") + // Add to non-existed header + header.Add("x-test-header-1", "baz") + + // check the cache + hdrs := header.Values("x-test-header-0") + if len(hdrs) != 2 || hdrs[0] != "foo" || hdrs[1] != "bar" { + return f.fail("header Values x-test-header-0: unexpected %v", hdrs) + } + + hdrs = header.Values("x-test-header-1") + if len(hdrs) != 1 || hdrs[0] != "baz" { + return f.fail("header Values x-test-header-1: unexpected %v", hdrs) + } + return api.Continue + } origin, found := header.Get("x-test-header-0") if found { @@ -184,6 +231,9 @@ func (f *filter) encodeData(buffer api.BufferInstance, endStream bool) api.Statu if f.sleep || f.data_sleep { time.Sleep(time.Millisecond * 100) // sleep 100 ms } + if f.add_header { + return api.Continue + } if strings.Contains(f.localreplay, "encode-data") { return f.sendLocalReply("encode-data") }