From 156fd876db314155dc655bd36837899031dacfd8 Mon Sep 17 00:00:00 2001 From: Walter Schulze Date: Thu, 25 Jan 2024 02:48:47 -0800 Subject: [PATCH] add support for setting headers via context Summary: In this diff we add a way to write headers via context This is for the use case of writing the kCat to the header in the follow up diff D53042701 We can AddHeaders to context with the new AddHeader method This context can be passed to the generated thrift RPC ChannelClient methods These generated methods calls SerialChannel.Call with context in the thrift library SerialChannel.Call calls SerialChannel.sendMsg with context SerialChannel.sendMsg adds the headers to the transport using setHeaders When request.Write is called it calls writeVarHeader in header.go writeVarHeader calls writeInfoHeaders writeInfoHeaders writes both the persistent and normal headers to the transport It has been tested before that the way persistentHeaders are written is a valid way to write headers that can be picked up by a C++ thrift client that is looking for the kCat header. Since the headers are written in the same way as the persistentHeaders we are confident that this will work. At the end of SerialChannel.sendMsg we call Flush The normal headers are cleared in the Flush function so that they are only written for the single request and not on every request like the persistent headers. Reviewed By: slasher-, echistyakov Differential Revision: D53042527 fbshipit-source-id: a415f6377b48a6d7d44c9bdfd5c2e0d1a490897f --- thrift/lib/go/thrift/context_headers.go | 71 ++++++++++++++++++++ thrift/lib/go/thrift/context_headers_test.go | 50 ++++++++++++++ thrift/lib/go/thrift/serial_channel.go | 10 ++- thrift/lib/go/thrift/transport_exception.go | 1 + 4 files changed, 129 insertions(+), 3 deletions(-) create mode 100644 thrift/lib/go/thrift/context_headers.go create mode 100644 thrift/lib/go/thrift/context_headers_test.go diff --git a/thrift/lib/go/thrift/context_headers.go b/thrift/lib/go/thrift/context_headers.go new file mode 100644 index 00000000000..1a73bb0870d --- /dev/null +++ b/thrift/lib/go/thrift/context_headers.go @@ -0,0 +1,71 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "context" + "fmt" +) + +// The headersKeyType type is unexported to prevent collisions with context keys. +type headersKeyType int + +const headersKey headersKeyType = 0 + +// AddHeader adds a header to the context, which will be sent as part of the request. +// AddHeader can be called multiple times to add multiple headers. +// These headers are not persistent and will only be sent with the current request. +func AddHeader(ctx context.Context, key string, value string) (context.Context, error) { + headersMap := make(map[string]string) + if headers := ctx.Value(headersKey); headers != nil { + var ok bool + headersMap, ok = headers.(map[string]string) + if !ok { + return nil, NewTransportException(INVALID_HEADERS_TYPE, "Headers key in context value is not map[string]string") + } + } + headersMap[key] = value + ctx = context.WithValue(ctx, headersKey, headersMap) + return ctx, nil +} + +// setHeaders sets the Headers in the transport to send with the request. +// These headers will be written via the Write method, inside the Call method for each generated request. +// These Headers will be cleared with Flush, as they are not persistent. +func setHeaders(ctx context.Context, transport Transport) error { + if ctx == nil { + return nil + } + headers := ctx.Value(headersKey) + if headers == nil { + return nil + } + headersMap, ok := headers.(map[string]string) + if !ok { + return NewTransportException(INVALID_HEADERS_TYPE, "Headers key in context value is not map[string]string") + } + switch t := transport.(type) { + case *HeaderTransport: + for k, v := range headersMap { + t.SetHeader(k, v) + } + default: + // TODO(T173277635): Support Rocket Transport + return NewTransportException(NOT_IMPLEMENTED, fmt.Sprintf("setHeaders not implemented for transport type %T", t)) + } + return nil +} diff --git a/thrift/lib/go/thrift/context_headers_test.go b/thrift/lib/go/thrift/context_headers_test.go new file mode 100644 index 00000000000..ae6cf9d7d8d --- /dev/null +++ b/thrift/lib/go/thrift/context_headers_test.go @@ -0,0 +1,50 @@ +/* + * Copyright (c) Meta Platforms, Inc. and affiliates. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package thrift + +import ( + "context" + "testing" + + "github.com/stretchr/testify/assert" +) + +func TestSomeHeaders(t *testing.T) { + ctx := context.Background() + want := map[string]string{"key1": "value1", "key2": "value2"} + var err error + for key, value := range want { + ctx, err = AddHeader(ctx, key, value) + if err != nil { + t.Fatal(err) + } + } + transport := NewHeaderTransport(NewMemoryBuffer()) + if err := setHeaders(ctx, transport); err != nil { + t.Fatal(err) + } + got := transport.Headers() + assert.Equal(t, want, got) +} + +// somewhere we are still passing context as nil, so we need to support this for now +func TestSetNilHeaders(t *testing.T) { + transport := NewHeaderTransport(NewMemoryBuffer()) + if err := setHeaders(nil, transport); err != nil { + t.Fatal(err) + } +} diff --git a/thrift/lib/go/thrift/serial_channel.go b/thrift/lib/go/thrift/serial_channel.go index 90f9d532749..f5fcbb2b8e4 100644 --- a/thrift/lib/go/thrift/serial_channel.go +++ b/thrift/lib/go/thrift/serial_channel.go @@ -38,10 +38,14 @@ func NewSerialChannel(protocol Protocol) *SerialChannel { } } -func (c *SerialChannel) sendMsg(method string, request IRequest, msgType MessageType) (int32, error) { +func (c *SerialChannel) sendMsg(ctx context.Context, method string, request IRequest, msgType MessageType) (int32, error) { c.seqID++ seqID := c.seqID + if err := setHeaders(ctx, c.protocol.Transport()); err != nil { + return seqID, err + } + if err := c.protocol.WriteMessageBegin(method, msgType, seqID); err != nil { return seqID, err } @@ -119,7 +123,7 @@ func (c *SerialChannel) Call(ctx context.Context, method string, request IReques c.lock.Lock() defer c.lock.Unlock() - seqID, err := c.sendMsg(method, request, CALL) + seqID, err := c.sendMsg(ctx, method, request, CALL) if err != nil { return err } @@ -138,7 +142,7 @@ func (c *SerialChannel) Oneway(ctx context.Context, method string, request IRequ c.lock.Lock() defer c.lock.Unlock() - _, err := c.sendMsg(method, request, ONEWAY) + _, err := c.sendMsg(ctx, method, request, ONEWAY) if err != nil { return err } diff --git a/thrift/lib/go/thrift/transport_exception.go b/thrift/lib/go/thrift/transport_exception.go index d154498c6e7..3cd931d0f70 100644 --- a/thrift/lib/go/thrift/transport_exception.go +++ b/thrift/lib/go/thrift/transport_exception.go @@ -47,6 +47,7 @@ const ( SSL_ERROR = 12 COULD_NOT_BIND = 13 NETWORK_ERROR = 15 + INVALID_HEADERS_TYPE = 16 ) type transportException struct {