Skip to content

Commit dabce82

Browse files
committed
http/requestid: fix boundary conditions
We don't want to allow arbitrarily long request ID-s. Don't log inside a package. Change-Id: I66bafd58ca7c587e82f25c3e8653ba79085c0080
1 parent cbf38d7 commit dabce82

File tree

3 files changed

+45
-29
lines changed

3 files changed

+45
-29
lines changed

http/doc.go

-5
This file was deleted.

http/requestid/requestid.go

+19-14
Original file line numberDiff line numberDiff line change
@@ -6,12 +6,9 @@ package requestid
66
import (
77
"context"
88
"crypto/rand"
9-
"encoding/base64"
10-
"fmt"
11-
"log"
129
"net/http"
1310

14-
"github.com/spacemonkeygo/monkit/v3"
11+
"storj.io/common/base58"
1512
)
1613

1714
// contextKey is the key that holds the unique request ID in a request context.
@@ -20,13 +17,25 @@ type contextKey struct{}
2017
// HeaderKey is the header key for the request ID.
2118
const HeaderKey = "X-Request-Id"
2219

20+
// MaxRequestID is the maximum allowed length for a request id.
21+
const MaxRequestID = 64
22+
2323
// AddToContext uses adds a unique requestid to the context and the response headers
2424
// of each request.
2525
func AddToContext(h http.Handler) http.Handler {
2626
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2727
requestID := r.Header.Get(HeaderKey)
28+
if len(requestID) > MaxRequestID {
29+
requestID = ""
30+
}
2831
if requestID == "" {
29-
requestID = generateRequestID()
32+
var err error
33+
requestID, err = generateRandomID()
34+
if err != nil {
35+
// If we fail to generate a random ID, then don't use one.
36+
h.ServeHTTP(w, r)
37+
return
38+
}
3039
}
3140

3241
w.Header().Set(HeaderKey, requestID)
@@ -48,15 +57,11 @@ func Propagate(ctx context.Context, req *http.Request) {
4857
req.Header.Set(HeaderKey, FromContext(ctx))
4958
}
5059

51-
// generateRequestID generates a random request ID using crypto/rand.
52-
// in case of an unlikely error, it falls back to using monkit.NewId().
53-
func generateRequestID() string {
54-
idBytes := make([]byte, 16)
55-
_, err := rand.Read(idBytes)
60+
func generateRandomID() (string, error) {
61+
var data [8]byte
62+
_, err := rand.Read(data[:])
5663
if err != nil {
57-
log.Printf("error generating request ID: %v", err)
58-
return fmt.Sprintf("%x", monkit.NewId())
64+
return "", err
5965
}
60-
61-
return base64.RawURLEncoding.EncodeToString(idBytes)
66+
return base58.Encode(data[:]), nil
6267
}

http/requestid/requestid_test.go

+26-10
Original file line numberDiff line numberDiff line change
@@ -17,30 +17,46 @@ import (
1717
func TestAddToContext(t *testing.T) {
1818
ctx := testcontext.New(t)
1919

20-
request, err := http.NewRequestWithContext(ctx, "GET", "", http.NoBody)
21-
require.NoError(t, err)
22-
23-
rw := httptest.NewRecorder()
24-
2520
var requestID string
26-
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
21+
handler := AddToContext(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
2722
w.WriteHeader(http.StatusOK)
2823
require.NotNil(t, r.Context().Value(contextKey{}), "RequestId should not be nil")
2924
require.NotEqual(t, "", r.Context().Value(contextKey{}).(string), "RequestId not set in Context")
3025

3126
requestID = r.Context().Value(contextKey{}).(string)
27+
}))
28+
29+
t.Run("success", func(t *testing.T) {
30+
rw := httptest.NewRecorder()
31+
32+
request, err := http.NewRequestWithContext(ctx, "GET", "", http.NoBody)
33+
require.NoError(t, err)
34+
handler.ServeHTTP(rw, request)
35+
36+
require.NotEqual(t, "", rw.Header().Get(HeaderKey), "RequestId is not set in response header")
37+
require.Equal(t, requestID, rw.Header().Get(HeaderKey), "Correct RequestId is not set in response header")
3238
})
3339

34-
newHandler := AddToContext(handler)
35-
newHandler.ServeHTTP(rw, request)
40+
t.Run("too-long", func(t *testing.T) {
41+
rw := httptest.NewRecorder()
42+
43+
request, err := http.NewRequestWithContext(ctx, "GET", "", http.NoBody)
44+
require.NoError(t, err)
45+
const tooLongKey = "01234567890123456789012345678901234567890123456789012345678901234567890123456789"
46+
request.Header.Set(HeaderKey, tooLongKey)
47+
48+
handler.ServeHTTP(rw, request)
3649

37-
require.NotEqual(t, "", rw.Header().Get(HeaderKey), "RequestId is not set in response header")
38-
require.Equal(t, requestID, rw.Header().Get(HeaderKey), "Correct RequestId is not set in response header")
50+
require.NotEqual(t, "", rw.Header().Get(HeaderKey), "RequestId is not set in response header")
51+
require.NotEqual(t, tooLongKey, requestID)
52+
})
3953
}
4054

4155
func TestPropagate(t *testing.T) {
4256
ctx := testcontext.New(t)
4357

58+
require.Equal(t, "", FromContext(ctx))
59+
4460
requestID := "test-request-id"
4561
reqctx := context.WithValue(ctx, contextKey{}, requestID)
4662

0 commit comments

Comments
 (0)