diff --git a/extension/googleclientauthextension/http_test.go b/extension/googleclientauthextension/http_test.go index 5750959ba..2e53ecca7 100644 --- a/extension/googleclientauthextension/http_test.go +++ b/extension/googleclientauthextension/http_test.go @@ -16,7 +16,11 @@ package googleclientauthextension // import "github.com/GoogleCloudPlatform/open import ( "context" + "encoding/json" + "fmt" "net/http" + "net/http/httptest" + "os" "testing" "github.com/stretchr/testify/assert" @@ -24,8 +28,30 @@ import ( "google.golang.org/api/idtoken" ) +func init() { + // Make sure metadata.OnGCE always returns true, since the result is + // cached. + os.Setenv("GCE_METADATA_HOST", "127.0.0.1") +} + func TestRoundTripper(t *testing.T) { - t.Setenv("GOOGLE_APPLICATION_CREDENTIALS", "testdata/fake_creds.json") + fakeToken := oauth2.Token{ + AccessToken: "accessToken", + TokenType: "tokenType", + RefreshToken: "refreshToken", + ExpiresIn: 1, + } + b, err := json.Marshal(fakeToken) + assert.NoError(t, err) + tokenString := string(b) + // Mimic metadata server, and return the fake access token. + srvProvidingTokens := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + fmt.Fprint(w, tokenString) + })) + defer srvProvidingTokens.Close() + t.Setenv("GCE_METADATA_HOST", srvProvidingTokens.Listener.Addr().String()) + ca := clientAuthenticator{ config: &Config{ Project: "my-project", @@ -34,14 +60,26 @@ func TestRoundTripper(t *testing.T) { TokenHeader: authorizationHeader, }, } - err := ca.Start(context.Background(), nil) + + err = ca.Start(t.Context(), nil) assert.NoError(t, err) rt, err := ca.RoundTripper(roundTripperFunc(func(r *http.Request) (*http.Response, error) { - return nil, nil + assert.Equal(t, r.Header.Get("X-Goog-User-Project"), "other-project") + assert.Equal(t, r.Header.Get("X-Goog-Project-ID"), "my-project") + assert.Equal(t, r.Header.Get("foo"), "bar") + if r.Header.Get("Authorization") != "tokenType accessToken" { + // Don't print this out in-case it is a real access token. + t.Error("Authorization header was incorrect. FindDefaultCredentials may have found real credentials.") + } + return &http.Response{}, nil })) assert.NotNil(t, rt) assert.NoError(t, err) + header := make(http.Header) + header.Set("foo", "bar") + _, err = rt.RoundTrip(&http.Request{Header: header}) + assert.NoError(t, err) } func TestRoundTripperWithIDToken(t *testing.T) {