diff --git a/lib/kube/proxy/forwarder.go b/lib/kube/proxy/forwarder.go index 31049202794e1..fc9585339dd6f 100644 --- a/lib/kube/proxy/forwarder.go +++ b/lib/kube/proxy/forwarder.go @@ -444,6 +444,21 @@ func (f *Forwarder) withAuthStd(handler handlerWithAuthFuncStd) http.HandlerFunc }, f.formatResponseError) } +// acquireConnectionLockWithIdentity acquires a connection lock under a given identity. +func (f *Forwarder) acquireConnectionLockWithIdentity(ctx context.Context, identity *authContext) error { + user := identity.Identity.GetIdentity().Username + roles, err := getRolesByName(f, identity.Identity.GetIdentity().Groups) + if err != nil { + return trace.Wrap(err) + } + + if err := f.acquireConnectionLock(ctx, user, roles); err != nil { + return trace.Wrap(err) + } + + return nil +} + func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle { return httplib.MakeHandlerWithErrorWriter(func(w http.ResponseWriter, req *http.Request, p httprouter.Params) (interface{}, error) { authContext, err := f.authenticate(req) @@ -453,16 +468,10 @@ func (f *Forwarder) withAuth(handler handlerWithAuthFunc) httprouter.Handle { if err := f.authorize(req.Context(), authContext); err != nil { return nil, trace.Wrap(err) } - - user := authContext.Identity.GetIdentity().Username - roles, err := getRolesByName(f, authContext.Identity.GetIdentity().Groups) + err = f.acquireConnectionLockWithIdentity(req.Context(), authContext) if err != nil { return nil, trace.Wrap(err) } - - if err := f.AcquireConnectionLock(req.Context(), user, roles); err != nil { - return nil, trace.Wrap(err) - } return handler(authContext, w, req, p) }, f.formatResponseError) } @@ -477,6 +486,10 @@ func (f *Forwarder) withAuthPassthrough(handler handlerWithAuthFunc) httprouter. return nil, trace.Wrap(err) } } + err = f.acquireConnectionLockWithIdentity(req.Context(), authContext) + if err != nil { + return nil, trace.Wrap(err) + } return handler(authContext, w, req, p) }, f.formatResponseError) } @@ -914,10 +927,10 @@ func wsProxy(wsSource *websocket.Conn, wsTarget *websocket.Conn) error { return trace.Wrap(err) } -// AcquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent. +// acquireConnectionLock acquires a semaphore used to limit connections to the Kubernetes agent. // The semaphore is releasted when the request is returned/connection is closed. // Returns an error if a semaphore could not be acquired. -func (f *Forwarder) AcquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error { +func (f *Forwarder) acquireConnectionLock(ctx context.Context, user string, roles services.RoleSet) error { maxConnections := roles.MaxKubernetesConnections() if maxConnections == 0 { return nil diff --git a/lib/kube/proxy/forwarder_test.go b/lib/kube/proxy/forwarder_test.go index 51a3d9198d168..e0376c6bc1c37 100644 --- a/lib/kube/proxy/forwarder_test.go +++ b/lib/kube/proxy/forwarder_test.go @@ -1068,7 +1068,8 @@ func newTestForwarder(ctx context.Context, cfg ForwarderConfig) *Forwarder { type mockSemaphoreClient struct { auth.ClientI - sem types.Semaphores + sem types.Semaphores + roles map[string]types.Role } func (m *mockSemaphoreClient) AcquireSemaphore(ctx context.Context, params types.AcquireSemaphoreRequest) (*types.SemaphoreLease, error) { @@ -1079,6 +1080,15 @@ func (m *mockSemaphoreClient) CancelSemaphoreLease(ctx context.Context, lease ty return m.sem.CancelSemaphoreLease(ctx, lease) } +func (m *mockSemaphoreClient) GetRole(ctx context.Context, name string) (types.Role, error) { + role, ok := m.roles[name] + if !ok { + return nil, trace.NotFound("role %q not found", name) + } + + return role, nil +} + func TestKubernetesConnectionLimit(t *testing.T) { ctx, cancel := context.WithCancel(context.Background()) defer cancel() @@ -1131,13 +1141,28 @@ func TestKubernetesConnectionLimit(t *testing.T) { require.NoError(t, err) sem := local.NewPresenceService(backend) - client := &mockSemaphoreClient{sem: sem} + client := &mockSemaphoreClient{ + sem: sem, + roles: map[string]types.Role{testCase.role.GetName(): testCase.role}, + } + forwarder := newTestForwarder(ctx, ForwarderConfig{ - AuthClient: client, + AuthClient: client, + CachingAuthClient: client, }) + identity := &authContext{ + Context: auth.Context{ + User: user, + Identity: auth.WrapIdentity(tlsca.Identity{ + Username: user.GetName(), + Groups: []string{testCase.role.GetName()}, + }), + }, + } + for i := 0; i < testCase.connections; i++ { - err = forwarder.AcquireConnectionLock(ctx, user.GetName(), services.NewRoleSet(testCase.role)) + err = forwarder.acquireConnectionLockWithIdentity(ctx, identity) if i == testCase.connections-1 { testCase.assert(t, err) }