diff --git a/api/utils/gcp/endpoint.go b/api/utils/gcp/endpoint.go new file mode 100644 index 0000000000000..a178d0f90b23e --- /dev/null +++ b/api/utils/gcp/endpoint.go @@ -0,0 +1,29 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gcp + +import ( + "strings" +) + +const ( + // gcpEndpointSuffix is hostname suffix used to determine if hostname is a GCP endpoint + gcpEndpointSuffix = ".googleapis.com" +) + +// IsGCPEndpoint returns true if hostname is a GCP endpoint +func IsGCPEndpoint(hostname string) bool { + return strings.HasSuffix(hostname, gcpEndpointSuffix) +} diff --git a/api/utils/gcp/endpoint_test.go b/api/utils/gcp/endpoint_test.go new file mode 100644 index 0000000000000..4e28596f766ae --- /dev/null +++ b/api/utils/gcp/endpoint_test.go @@ -0,0 +1,60 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gcp + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestIsGCPEndpoint(t *testing.T) { + tests := []struct { + name string + hostname string + want bool + }{ + { + name: "compute googleapis", + hostname: "compute.googleapis.com", + want: true, + }, + { + name: "top level googleapis", + hostname: "googleapis.com", + want: false, + }, + { + name: "localhost", + hostname: "localhost", + want: false, + }, + { + name: "fake googleapis", + hostname: "compute.googleapis.com.fake.com", + want: false, + }, + { + name: "top level-like fake googleapis", + hostname: "googleapis.com.fake.com", + want: false, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.want, IsGCPEndpoint(tt.hostname)) + }) + } +} diff --git a/lib/srv/app/azure/handler.go b/lib/srv/app/azure/handler.go index 9ff67a2877899..ca66f544166bb 100644 --- a/lib/srv/app/azure/handler.go +++ b/lib/srv/app/azure/handler.go @@ -89,8 +89,13 @@ type handler struct { tokenCache *utils.FnCache } -// NewAzureHandler creates a new instance of an http.Handler for Azure requests. +// NewAzureHandler creates a new instance of a http.Handler for Azure requests. func NewAzureHandler(ctx context.Context, config HandlerConfig) (http.Handler, error) { + return newAzureHandler(ctx, config) +} + +// newAzureHandler creates a new instance of a handler for Azure requests. Used by NewAzureHandler and in tests. +func newAzureHandler(ctx context.Context, config HandlerConfig) (*handler, error) { if err := config.CheckAndSetDefaults(); err != nil { return nil, trace.Wrap(err) } @@ -282,10 +287,23 @@ const getTokenTimeout = time.Second * 5 func (s *handler) getToken(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { key := cacheKey{managedIdentity, scope} - timeoutCtx, cancel := context.WithTimeout(ctx, getTokenTimeout) + cancelCtx, cancel := context.WithCancel(ctx) defer cancel() - return utils.FnCacheGet(timeoutCtx, s.tokenCache, key, func(ctx context.Context) (*azcore.AccessToken, error) { - return s.getAccessToken(ctx, managedIdentity, scope) - }) + var tokenResult *azcore.AccessToken + var errorResult error + + go func() { + tokenResult, errorResult = utils.FnCacheGet(cancelCtx, s.tokenCache, key, func(ctx context.Context) (*azcore.AccessToken, error) { + return s.getAccessToken(ctx, managedIdentity, scope) + }) + cancel() + }() + + select { + case <-s.Clock.After(getTokenTimeout): + return nil, trace.Wrap(context.DeadlineExceeded, "timeout waiting for access token for %v", getTokenTimeout) + case <-cancelCtx.Done(): + return tokenResult, errorResult + } } diff --git a/lib/srv/app/azure/handler_test.go b/lib/srv/app/azure/handler_test.go new file mode 100644 index 0000000000000..85a0f2a035337 --- /dev/null +++ b/lib/srv/app/azure/handler_test.go @@ -0,0 +1,148 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package azure + +import ( + "context" + "testing" + "time" + + "github.com/Azure/azure-sdk-for-go/sdk/azcore" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/require" +) + +func TestForwarder_getToken(t *testing.T) { + t.Parallel() + + type testCase struct { + name string + + config HandlerConfig + + managedIdentity string + scope string + + wantToken *azcore.AccessToken + checkErr require.ErrorAssertionFunc + } + + var tests []testCase + + tests = []testCase{ + { + name: "base case", + config: HandlerConfig{ + getAccessToken: func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { + if managedIdentity != "MY_IDENTITY" { + return nil, trace.BadParameter("wrong managedIdentity") + } + if scope != "MY_SCOPE" { + return nil, trace.BadParameter("wrong scope") + } + return &azcore.AccessToken{Token: "foobar"}, nil + }, + }, + managedIdentity: "MY_IDENTITY", + scope: "MY_SCOPE", + wantToken: &azcore.AccessToken{Token: "foobar"}, + checkErr: require.NoError, + }, + { + name: "timeout", + config: HandlerConfig{ + Clock: clockwork.NewFakeClock(), + getAccessToken: func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { + // find the fake clock from above + var clock clockwork.FakeClock + for _, test := range tests { + if test.name == "timeout" { + clock = test.config.Clock.(clockwork.FakeClock) + } + } + + clock.Advance(getTokenTimeout) + clock.Sleep(getTokenTimeout * 2) + return &azcore.AccessToken{Token: "foobar"}, nil + }, + }, + checkErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "timeout waiting for access token for 5s") + require.ErrorIs(t, err, context.DeadlineExceeded) + }, + }, + { + name: "non-timeout error", + config: HandlerConfig{ + getAccessToken: func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { + return nil, trace.BadParameter("bad param foo") + }, + }, + checkErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "bad param foo") + require.True(t, trace.IsBadParameter(err)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + ctx := context.Background() + + hnd, err := newAzureHandler(ctx, tt.config) + require.NoError(t, err) + + token, err := hnd.getToken(ctx, tt.managedIdentity, tt.scope) + + require.Equal(t, tt.wantToken, token) + tt.checkErr(t, err) + }) + } +} + +func TestForwarder_getToken_cache(t *testing.T) { + ctx := context.Background() + + clock := clockwork.NewFakeClock() + + calls := 0 + hnd, err := newAzureHandler(ctx, HandlerConfig{ + Clock: clock, + getAccessToken: func(ctx context.Context, managedIdentity string, scope string) (*azcore.AccessToken, error) { + calls++ + return &azcore.AccessToken{Token: "OK"}, nil + }, + }) + require.NoError(t, err) + + // first call goes through + _, err = hnd.getToken(ctx, "", "") + require.NoError(t, err) + require.Equal(t, 1, calls) + + // second call is cached + _, err = hnd.getToken(ctx, "", "") + require.NoError(t, err) + require.Equal(t, 1, calls) + + // advance past cache expiry + clock.Advance(time.Second * 60 * 2) + + // third call goes through + _, err = hnd.getToken(ctx, "", "") + require.NoError(t, err) + require.Equal(t, 2, calls) +} diff --git a/lib/srv/app/gcp/handler.go b/lib/srv/app/gcp/handler.go new file mode 100644 index 0000000000000..13659fd785a40 --- /dev/null +++ b/lib/srv/app/gcp/handler.go @@ -0,0 +1,296 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gcp + +import ( + "bytes" + "context" + "fmt" + "net/http" + "time" + + gcpcredentials "cloud.google.com/go/iam/credentials/apiv1" + "cloud.google.com/go/iam/credentials/apiv1/credentialspb" + "github.com/googleapis/gax-go/v2" + "github.com/gravitational/oxy/forward" + oxyutils "github.com/gravitational/oxy/utils" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/sirupsen/logrus" + + "github.com/gravitational/teleport/api/utils/gcp" + "github.com/gravitational/teleport/lib/cloud" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/teleport/lib/srv/app/common" + "github.com/gravitational/teleport/lib/utils" + awsutils "github.com/gravitational/teleport/lib/utils/aws" +) + +// iamCredentialsClient is an interface that defines the methods which we use from IAM Service Account Credentials API. +// It is implemented by *gcpcredentials.IamCredentialsClient and can be mocked in tests unlike the concrete struct. +type iamCredentialsClient interface { + GenerateAccessToken(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) +} + +// cloudClientGCP is an interface that defines the GetGCPIAMClient method we use in this module. +type cloudClientGCP interface { + GetGCPIAMClient(context.Context) (iamCredentialsClient, error) +} + +// cloudClientGCPImpl is a wrapper around callback function implementing cloudClientGCP interface. +type cloudClientGCPImpl[T iamCredentialsClient] struct { + getGCPIAMClient func(ctx context.Context) (T, error) +} + +func (t *cloudClientGCPImpl[T]) GetGCPIAMClient(ctx context.Context) (iamCredentialsClient, error) { + return t.getGCPIAMClient(ctx) +} + +var _ cloudClientGCP = (*cloudClientGCPImpl[iamCredentialsClient])(nil) + +// HandlerConfig is the configuration for an GCP app-access handler. +type HandlerConfig struct { + // RoundTripper is the underlying transport given to an oxy Forwarder. + RoundTripper http.RoundTripper + // Log is the Logger. + Log logrus.FieldLogger + // Clock is used to override time in tests. + Clock clockwork.Clock + // cloudClientGCP holds a reference to GCP IAM client. Normally set in CheckAndSetDefaults, it is overridden in tests. + cloudClientGCP cloudClientGCP +} + +// CheckAndSetDefaults validates the HandlerConfig. +func (s *HandlerConfig) CheckAndSetDefaults() error { + if s.RoundTripper == nil { + tr, err := defaults.Transport() + if err != nil { + return trace.Wrap(err) + } + s.RoundTripper = tr + } + if s.Clock == nil { + s.Clock = clockwork.NewRealClock() + } + if s.Log == nil { + s.Log = logrus.WithField(trace.Component, "gcp:fwd") + } + if s.cloudClientGCP == nil { + clients := cloud.NewClients() + s.cloudClientGCP = &cloudClientGCPImpl[*gcpcredentials.IamCredentialsClient]{getGCPIAMClient: clients.GetGCPIAMClient} + } + return nil +} + +// Forwarder is an GCP CLI proxy service that forwards the requests to GCP API, but updates the authorization headers +// based on user identity. +type handler struct { + // config is the handler configuration. + HandlerConfig + + // fwd is used to forward requests to GCP API after the handler has rewritten them. + fwd *forward.Forwarder + + // tokenCache caches access tokens. + tokenCache *utils.FnCache +} + +// NewGCPHandler creates a new instance of a http.Handler for GCP requests. +func NewGCPHandler(ctx context.Context, config HandlerConfig) (http.Handler, error) { + return newGCPHandler(ctx, config) +} + +// newGCPHandler creates a new instance of a handler for GCP requests. Used by NewGCPHandler and in tests. +func newGCPHandler(ctx context.Context, config HandlerConfig) (*handler, error) { + if err := config.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + + tokenCache, err := utils.NewFnCache(utils.FnCacheConfig{ + TTL: time.Second * 60, + Clock: config.Clock, + Context: ctx, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + svc := &handler{ + HandlerConfig: config, + tokenCache: tokenCache, + } + + fwd, err := forward.New( + forward.RoundTripper(config.RoundTripper), + forward.ErrorHandler(oxyutils.ErrorHandlerFunc(svc.formatForwardResponseError)), + // Explicitly passing false here to be clear that we always want the host + // header to be the same as the outbound request's URL host. + forward.PassHostHeader(false), + ) + if err != nil { + return nil, trace.Wrap(err) + } + svc.fwd = fwd + return svc, nil +} + +// RoundTrip handles incoming requests and forwards them to the proper API. +func (s *handler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + if err := s.serveHTTP(w, req); err != nil { + s.formatForwardResponseError(w, req, err) + return + } +} + +// serveHTTP is a helper to simplify error handling in ServeHTTP. +func (s *handler) serveHTTP(w http.ResponseWriter, req *http.Request) error { + sessionCtx, err := common.GetSessionContext(req) + if err != nil { + return trace.Wrap(err) + } + s.Log.Debugf("Processing request, sessionId = %q, gcpServiceAccount = %q", sessionCtx.Identity.RouteToApp.SessionID, sessionCtx.Identity.RouteToApp.GCPServiceAccount) + + fwdRequest, err := s.prepareForwardRequest(req, sessionCtx) + if err != nil { + return trace.Wrap(err) + } + recorder := httplib.NewResponseStatusRecorder(w) + s.fwd.ServeHTTP(recorder, fwdRequest) + status := uint32(recorder.Status()) + + if err := sessionCtx.Audit.OnRequest(req.Context(), sessionCtx, fwdRequest, status, nil); err != nil { + // log but don't return the error, because we already handed off request/response handling to the oxy forwarder. + s.Log.WithError(err).Warn("Failed to emit audit event.") + } + return nil +} + +func (s *handler) formatForwardResponseError(rw http.ResponseWriter, r *http.Request, err error) { + s.Log.WithError(err).Debugf("Failed to process request.") + common.SetTeleportAPIErrorHeader(rw, err) + + // Convert trace error type to HTTP and write response. + code := trace.ErrorToCode(err) + http.Error(rw, http.StatusText(code), code) +} + +// prepareForwardRequest prepares a request for forwarding, updating headers and target host. Several checks are made along the way. +func (s *handler) prepareForwardRequest(r *http.Request, sessionCtx *common.SessionContext) (*http.Request, error) { + forwardedHost := r.Header.Get("X-Forwarded-Host") + if !gcp.IsGCPEndpoint(forwardedHost) { + return nil, trace.AccessDenied("%q is not a GCP endpoint", forwardedHost) + } + + payload, err := awsutils.GetAndReplaceReqBody(r) + if err != nil { + return nil, trace.Wrap(err) + } + + reqCopy, err := http.NewRequest(r.Method, r.URL.String(), bytes.NewReader(payload)) + if err != nil { + return nil, trace.Wrap(err) + } + + reqCopy.URL.Scheme = "https" + reqCopy.URL.Host = forwardedHost + reqCopy.Header = r.Header.Clone() + + err = s.replaceAuthHeaders(r, sessionCtx, reqCopy) + if err != nil { + return nil, trace.Wrap(err) + } + + return reqCopy, trace.Wrap(err) +} + +func (s *handler) replaceAuthHeaders(r *http.Request, sessionCtx *common.SessionContext, reqCopy *http.Request) error { + auth := reqCopy.Header.Get("Authorization") + if auth == "" { + s.Log.Debugf("No Authorization header present, skipping replacement.") + return nil + } + + token, err := s.getToken(r.Context(), sessionCtx.Identity.RouteToApp.GCPServiceAccount) + if err != nil { + return trace.Wrap(err) + } + + // Set new authorization + reqCopy.Header.Set("Authorization", "Bearer "+token.AccessToken) + return nil +} + +type cacheKey struct { + serviceAccount string +} + +const getTokenTimeout = time.Second * 5 + +// defaultScopeList is a fixed list of scopes requested for a token. +// If needed we can extend it or make it configurable. +// For scope documentation see: https://developers.google.com/identity/protocols/oauth2/scopes +var defaultScopeList = []string{ + "https://www.googleapis.com/auth/cloud-platform", + + "openid", + "https://www.googleapis.com/auth/userinfo.email", + "https://www.googleapis.com/auth/appengine.admin", + "https://www.googleapis.com/auth/sqlservice.login", + "https://www.googleapis.com/auth/compute", +} + +func (s *handler) getToken(ctx context.Context, serviceAccount string) (*credentialspb.GenerateAccessTokenResponse, error) { + key := cacheKey{serviceAccount} + + cancelCtx, cancel := context.WithCancel(ctx) + defer cancel() + + var tokenResult *credentialspb.GenerateAccessTokenResponse + var errorResult error + + go func() { + tokenResult, errorResult = utils.FnCacheGet(cancelCtx, s.tokenCache, key, func(ctx context.Context) (*credentialspb.GenerateAccessTokenResponse, error) { + return s.generateAccessToken(ctx, serviceAccount, defaultScopeList) + }) + cancel() + }() + + select { + case <-s.Clock.After(getTokenTimeout): + return nil, trace.Wrap(context.DeadlineExceeded, "timeout waiting for access token for %v", getTokenTimeout) + case <-cancelCtx.Done(): + return tokenResult, errorResult + } +} + +func (s *handler) generateAccessToken(ctx context.Context, serviceAccount string, scopes []string) (*credentialspb.GenerateAccessTokenResponse, error) { + client, err := s.cloudClientGCP.GetGCPIAMClient(ctx) + if err != nil { + return nil, trace.Wrap(err) + } + + request := &credentialspb.GenerateAccessTokenRequest{ + // expected format: projects/-/serviceAccounts/{ACCOUNT_EMAIL_OR_UNIQUEID} + Name: fmt.Sprintf("projects/-/serviceAccounts/%v", serviceAccount), + Scope: scopes, + } + accessToken, err := client.GenerateAccessToken(ctx, request) + if err != nil { + return nil, trace.Wrap(err) + } + + return accessToken, nil +} diff --git a/lib/srv/app/gcp/handler_test.go b/lib/srv/app/gcp/handler_test.go new file mode 100644 index 0000000000000..14fd09d2079c4 --- /dev/null +++ b/lib/srv/app/gcp/handler_test.go @@ -0,0 +1,179 @@ +// Copyright 2022 Gravitational, Inc +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package gcp + +import ( + "context" + "testing" + "time" + + "cloud.google.com/go/iam/credentials/apiv1/credentialspb" + "github.com/googleapis/gax-go/v2" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +type testIAMCredentialsClient struct { + generateAccessToken func(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) +} + +func (i *testIAMCredentialsClient) GenerateAccessToken(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) { + return i.generateAccessToken(ctx, req, opts...) +} + +var _ iamCredentialsClient = (*testIAMCredentialsClient)(nil) + +func makeTestCloudClient(client *testIAMCredentialsClient) cloudClientGCP { + return &cloudClientGCPImpl[*testIAMCredentialsClient]{getGCPIAMClient: func(ctx context.Context) (*testIAMCredentialsClient, error) { + return client, nil + }} +} + +func TestHandler_getToken(t *testing.T) { + mkConstConfig := func(val HandlerConfig) func(any) HandlerConfig { + return func(_ any) HandlerConfig { + return val + } + } + + tests := []struct { + name string + + initState func() any + + config func(state any) HandlerConfig + + wantToken *credentialspb.GenerateAccessTokenResponse + checkErr require.ErrorAssertionFunc + checkState func(require.TestingT, any) + }{ + { + name: "base case", + config: mkConstConfig(HandlerConfig{ + cloudClientGCP: makeTestCloudClient(&testIAMCredentialsClient{ + generateAccessToken: func(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) { + if req.GetName() != "projects/-/serviceAccounts/MY_ACCOUNT" { + return nil, trace.BadParameter("wrong serviceAccount, expected %q got %q", "projects/-/serviceAccounts/MY_ACCOUNT", req.GetName()) + } + if !assert.ObjectsAreEqual(req.GetScope(), defaultScopeList) { + return nil, trace.BadParameter("wrong scopes") + } + return &credentialspb.GenerateAccessTokenResponse{AccessToken: "ok"}, nil + }, + }), + }), + wantToken: &credentialspb.GenerateAccessTokenResponse{AccessToken: "ok"}, + checkErr: require.NoError, + }, + { + name: "timeout", + initState: func() any { + return clockwork.NewFakeClockAt(time.Date(2023, 1, 1, 12, 00, 00, 000, time.UTC)) + }, + config: func(state any) HandlerConfig { + return HandlerConfig{ + Clock: state.(clockwork.FakeClock).(clockwork.Clock), + cloudClientGCP: makeTestCloudClient(&testIAMCredentialsClient{ + generateAccessToken: func(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) { + clock := state.(clockwork.FakeClock) + clock.Advance(getTokenTimeout) + + clock.Sleep(getTokenTimeout * 2) + return nil, trace.BadParameter("bad param foo") + }, + }), + } + }, + checkErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "timeout waiting for access token for 5s") + require.ErrorIs(t, err, context.DeadlineExceeded) + }, + }, + { + name: "non-timeout error", + config: mkConstConfig(HandlerConfig{ + cloudClientGCP: makeTestCloudClient(&testIAMCredentialsClient{ + generateAccessToken: func(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) { + return nil, trace.BadParameter("bad param foo") + }, + }), + }), + checkErr: func(t require.TestingT, err error, i ...interface{}) { + require.ErrorContains(t, err, "bad param foo") + require.True(t, trace.IsBadParameter(err)) + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var state any + if tt.initState != nil { + state = tt.initState() + } + + ctx := context.Background() + + fwd, err := newGCPHandler(ctx, tt.config(state)) + require.NoError(t, err) + + token, err := fwd.getToken(ctx, "MY_ACCOUNT") + require.Equal(t, tt.wantToken, token) + tt.checkErr(t, err) + + if tt.checkState != nil { + tt.checkState(t, state) + } + }) + } +} + +func TestHandler_getToken_cache(t *testing.T) { + ctx := context.Background() + + clock := clockwork.NewFakeClock() + + calls := 0 + fwd, err := newGCPHandler(ctx, HandlerConfig{ + Clock: clock, + cloudClientGCP: makeTestCloudClient(&testIAMCredentialsClient{ + generateAccessToken: func(ctx context.Context, req *credentialspb.GenerateAccessTokenRequest, opts ...gax.CallOption) (*credentialspb.GenerateAccessTokenResponse, error) { + calls++ + return &credentialspb.GenerateAccessTokenResponse{AccessToken: "ok"}, nil + }, + }), + }) + require.NoError(t, err) + + // first call goes through + _, err = fwd.getToken(ctx, "") + require.NoError(t, err) + require.Equal(t, 1, calls) + + // second call is cached + _, err = fwd.getToken(ctx, "") + require.NoError(t, err) + require.Equal(t, 1, calls) + + // advance past cache expiry + clock.Advance(time.Second * 60 * 2) + + // third call goes through + _, err = fwd.getToken(ctx, "") + require.NoError(t, err) + require.Equal(t, 2, calls) +} diff --git a/lib/srv/app/server.go b/lib/srv/app/server.go index 9df1838e2d54a..9a0f4ae3bf6ca 100644 --- a/lib/srv/app/server.go +++ b/lib/srv/app/server.go @@ -48,6 +48,7 @@ import ( appaws "github.com/gravitational/teleport/lib/srv/app/aws" appazure "github.com/gravitational/teleport/lib/srv/app/azure" "github.com/gravitational/teleport/lib/srv/app/common" + appgcp "github.com/gravitational/teleport/lib/srv/app/gcp" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" awsutils "github.com/gravitational/teleport/lib/utils/aws" @@ -213,6 +214,7 @@ type Server struct { awsHandler http.Handler azureHandler http.Handler + gcpHandler http.Handler // watcher monitors changes to application resources. watcher *services.AppWatcher @@ -279,6 +281,11 @@ func New(ctx context.Context, c *Config) (*Server, error) { return nil, trace.Wrap(err) } + gcpHandler, err := appgcp.NewGCPHandler(closeContext, appgcp.HandlerConfig{}) + if err != nil { + return nil, trace.Wrap(err) + } + s := &Server{ c: c, log: logrus.WithFields(logrus.Fields{ @@ -290,6 +297,7 @@ func New(ctx context.Context, c *Config) (*Server, error) { connAuth: make(map[net.Conn]error), awsHandler: awsHandler, azureHandler: azureHandler, + gcpHandler: gcpHandler, monitoredApps: monitoredApps{ static: c.Apps, }, @@ -812,7 +820,10 @@ func (s *Server) serveHTTP(w http.ResponseWriter, r *http.Request) error { return s.serveAWSWebConsole(w, r, &identity, app) case app.IsAzureCloud(): - return s.serveSession(w, r, &identity, app, s.withAzureForwarder) + return s.serveSession(w, r, &identity, app, s.withAzureHandler) + + case app.IsGCP(): + return s.serveSession(w, r, &identity, app, s.withGCPHandler) default: return s.serveSession(w, r, &identity, app, s.withJWTTokenForwarder) @@ -930,6 +941,14 @@ func (s *Server) authorizeContext(ctx context.Context) (*auth.Context, types.App }) } + // When accessing GCP API, check permissions to assume + // requested GCP service account as well. + if app.IsGCP() { + matchers = append(matchers, &services.GCPServiceAccountMatcher{ + ServiceAccount: identity.RouteToApp.GCPServiceAccount, + }) + } + mfaParams := authContext.MFAParams(ap.GetRequireMFAType()) err = authContext.Checker.CheckAccess( app, diff --git a/lib/srv/app/session.go b/lib/srv/app/session.go index 24215c4492063..9215260c0270f 100644 --- a/lib/srv/app/session.go +++ b/lib/srv/app/session.go @@ -205,11 +205,16 @@ func (s *Server) withAWSSigner(_ context.Context, sess *sessionChunk, _ *tlsca.I return nil } -func (s *Server) withAzureForwarder(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error { +func (s *Server) withAzureHandler(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error { sess.handler = s.azureHandler return nil } +func (s *Server) withGCPHandler(ctx context.Context, sess *sessionChunk, identity *tlsca.Identity, app types.Application) error { + sess.handler = s.gcpHandler + return nil +} + // acquire() increments in-flight request count by 1. // It is supposed to be paired with a `release()` call, // after the chunk is done with for the individual request