diff --git a/cmd/query/app/token_propagation_hander_test.go b/cmd/query/app/token_propagation_hander_test.go index 67d7de02342..8e1f2a5428b 100644 --- a/cmd/query/app/token_propagation_hander_test.go +++ b/cmd/query/app/token_propagation_hander_test.go @@ -51,16 +51,19 @@ func Test_bearTokenPropagationHandler(t *testing.T) { } testCases := []struct { - name string - sendHeader bool - header string - handler func(stop *sync.WaitGroup) http.HandlerFunc + name string + sendHeader bool + headerValue string + headerName string + handler func(stop *sync.WaitGroup) http.HandlerFunc }{ - { name:"Bearer token", sendHeader: true, header: "Bearer " + bearerToken, handler:validTokenHandler}, - { name:"Invalid header",sendHeader: true, header: bearerToken, handler:emptyHandler}, - { name:"No header", sendHeader: false, handler:emptyHandler}, - { name:"Basic Auth", sendHeader: true, header: "Basic " + bearerToken, handler:emptyHandler}, - { name:"X-Forwarded-Access-Token", sendHeader: true, header: "Bearer " + bearerToken, handler:validTokenHandler}, + { name:"Bearer token", sendHeader: true, headerName:"Authorization", headerValue: "Bearer " + bearerToken, handler:validTokenHandler}, + { name:"Raw bearer token",sendHeader: true, headerName:"Authorization", headerValue: bearerToken, handler:validTokenHandler}, + { name:"No headerValue", sendHeader: false, headerName:"Authorization", handler:emptyHandler}, + { name:"Basic Auth", sendHeader: true, headerName:"Authorization", headerValue: "Basic " + bearerToken, handler:emptyHandler}, + { name:"X-Forwarded-Access-Token", headerName:"X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken, handler:validTokenHandler}, + { name:"Invalid header", headerName:"X-Forwarded-Access-Token", sendHeader: true, headerValue: "Bearer " + bearerToken + " another stuff", handler:emptyHandler}, + } for _, testCase := range testCases { @@ -73,7 +76,7 @@ func Test_bearTokenPropagationHandler(t *testing.T) { req , err := http.NewRequest("GET", server.URL, nil) assert.Nil(t,err) if testCase.sendHeader { - req.Header.Add("Authorization", testCase.header) + req.Header.Add(testCase.headerName, testCase.headerValue) } _, err = httpClient.Do(req) assert.Nil(t, err) diff --git a/cmd/query/app/token_propagation_handler.go b/cmd/query/app/token_propagation_handler.go index c16b9ece662..26a73e76e14 100644 --- a/cmd/query/app/token_propagation_handler.go +++ b/cmd/query/app/token_propagation_handler.go @@ -15,6 +15,7 @@ package app import ( + "log" "net/http" "strings" @@ -26,11 +27,15 @@ import ( func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := r.Context() + logger.Info("Propagating bearer token") + log.Print(r) authHeaderValue := r.Header.Get("Authorization") - // If no Authorization header is present, try with X-Forwarded-Access-Token + // If no Authorization headerValue is present, try with X-Forwarded-Access-Token if authHeaderValue == "" { authHeaderValue = r.Header.Get("X-Forwarded-Access-Token") } + logger.Info("Token: " + authHeaderValue) + if authHeaderValue != "" { headerValue := strings.Split(authHeaderValue, " ") token := "" @@ -38,9 +43,15 @@ func bearerTokenPropagationHandler(logger *zap.Logger, h http.Handler) http.Hand // Make sure we only capture bearer token , not other types like Basic auth. if headerValue[0] == "Bearer" { token = headerValue[1] + } else { + logger.Warn("Unsupported type of token " + headerValue[0] + " skipping token propagation") } + } else if len(headerValue) == 1 { + // Tread all value as a token + logger.Info("Token type does not specified in authorization header, treating all value as the bearer token") + token = authHeaderValue } else { - logger.Warn("Invalid authorization header, skipping bearer token propagation") + logger.Warn("Invalid authorization header value, skipping token propagation") } h.ServeHTTP(w, r.WithContext(spanstore.ContextWithBearerToken(ctx, token))) } else {