Skip to content

Commit

Permalink
add support for setting headers via context
Browse files Browse the repository at this point in the history
Summary:
In this diff we add a way to write headers via context
This is for the use case of writing the kCat to the header in the follow up diff D53042701

We can AddHeaders to context with the new AddHeader method
This context can be passed to the generated thrift RPC ChannelClient methods
These generated methods calls SerialChannel.Call with context in the thrift library
SerialChannel.Call calls SerialChannel.sendMsg with context
SerialChannel.sendMsg adds the headers to the transport using setHeaders
When request.Write is called it calls writeVarHeader in header.go
writeVarHeader calls writeInfoHeaders
writeInfoHeaders writes both the persistent and normal headers to the transport

It has been tested before that the way persistentHeaders are written is a valid way to write headers that can be picked up by a C++ thrift client that is looking for the kCat header.
Since the headers are written in the same way as the persistentHeaders we are confident that this will work.

At the end of SerialChannel.sendMsg we call Flush
The normal headers are cleared in the Flush function so that they are only written for the single request and not on every request like the persistent headers.

Reviewed By: slasher-, echistyakov

Differential Revision: D53042527

fbshipit-source-id: a415f6377b48a6d7d44c9bdfd5c2e0d1a490897f
  • Loading branch information
awalterschulze authored and facebook-github-bot committed Jan 25, 2024
1 parent 72d8cc3 commit 156fd87
Show file tree
Hide file tree
Showing 4 changed files with 129 additions and 3 deletions.
71 changes: 71 additions & 0 deletions thrift/lib/go/thrift/context_headers.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"context"
"fmt"
)

// The headersKeyType type is unexported to prevent collisions with context keys.
type headersKeyType int

const headersKey headersKeyType = 0

// AddHeader adds a header to the context, which will be sent as part of the request.
// AddHeader can be called multiple times to add multiple headers.
// These headers are not persistent and will only be sent with the current request.
func AddHeader(ctx context.Context, key string, value string) (context.Context, error) {
headersMap := make(map[string]string)
if headers := ctx.Value(headersKey); headers != nil {
var ok bool
headersMap, ok = headers.(map[string]string)
if !ok {
return nil, NewTransportException(INVALID_HEADERS_TYPE, "Headers key in context value is not map[string]string")
}
}
headersMap[key] = value
ctx = context.WithValue(ctx, headersKey, headersMap)
return ctx, nil
}

// setHeaders sets the Headers in the transport to send with the request.
// These headers will be written via the Write method, inside the Call method for each generated request.
// These Headers will be cleared with Flush, as they are not persistent.
func setHeaders(ctx context.Context, transport Transport) error {
if ctx == nil {
return nil
}
headers := ctx.Value(headersKey)
if headers == nil {
return nil
}
headersMap, ok := headers.(map[string]string)
if !ok {
return NewTransportException(INVALID_HEADERS_TYPE, "Headers key in context value is not map[string]string")
}
switch t := transport.(type) {
case *HeaderTransport:
for k, v := range headersMap {
t.SetHeader(k, v)
}
default:
// TODO(T173277635): Support Rocket Transport
return NewTransportException(NOT_IMPLEMENTED, fmt.Sprintf("setHeaders not implemented for transport type %T", t))
}
return nil
}
50 changes: 50 additions & 0 deletions thrift/lib/go/thrift/context_headers_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
/*
* Copyright (c) Meta Platforms, Inc. and affiliates.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package thrift

import (
"context"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSomeHeaders(t *testing.T) {
ctx := context.Background()
want := map[string]string{"key1": "value1", "key2": "value2"}
var err error
for key, value := range want {
ctx, err = AddHeader(ctx, key, value)
if err != nil {
t.Fatal(err)
}
}
transport := NewHeaderTransport(NewMemoryBuffer())
if err := setHeaders(ctx, transport); err != nil {
t.Fatal(err)
}
got := transport.Headers()
assert.Equal(t, want, got)
}

// somewhere we are still passing context as nil, so we need to support this for now
func TestSetNilHeaders(t *testing.T) {
transport := NewHeaderTransport(NewMemoryBuffer())
if err := setHeaders(nil, transport); err != nil {
t.Fatal(err)
}
}
10 changes: 7 additions & 3 deletions thrift/lib/go/thrift/serial_channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,14 @@ func NewSerialChannel(protocol Protocol) *SerialChannel {
}
}

func (c *SerialChannel) sendMsg(method string, request IRequest, msgType MessageType) (int32, error) {
func (c *SerialChannel) sendMsg(ctx context.Context, method string, request IRequest, msgType MessageType) (int32, error) {
c.seqID++
seqID := c.seqID

if err := setHeaders(ctx, c.protocol.Transport()); err != nil {
return seqID, err
}

if err := c.protocol.WriteMessageBegin(method, msgType, seqID); err != nil {
return seqID, err
}
Expand Down Expand Up @@ -119,7 +123,7 @@ func (c *SerialChannel) Call(ctx context.Context, method string, request IReques
c.lock.Lock()
defer c.lock.Unlock()

seqID, err := c.sendMsg(method, request, CALL)
seqID, err := c.sendMsg(ctx, method, request, CALL)
if err != nil {
return err
}
Expand All @@ -138,7 +142,7 @@ func (c *SerialChannel) Oneway(ctx context.Context, method string, request IRequ
c.lock.Lock()
defer c.lock.Unlock()

_, err := c.sendMsg(method, request, ONEWAY)
_, err := c.sendMsg(ctx, method, request, ONEWAY)
if err != nil {
return err
}
Expand Down
1 change: 1 addition & 0 deletions thrift/lib/go/thrift/transport_exception.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,7 @@ const (
SSL_ERROR = 12
COULD_NOT_BIND = 13
NETWORK_ERROR = 15
INVALID_HEADERS_TYPE = 16
)

type transportException struct {
Expand Down

0 comments on commit 156fd87

Please sign in to comment.