diff --git a/api/breaker/round_tripper.go b/api/breaker/round_tripper.go index 5eec8c61b1520..17b7b9eeb415b 100644 --- a/api/breaker/round_tripper.go +++ b/api/breaker/round_tripper.go @@ -59,3 +59,8 @@ func (t *RoundTripper) RoundTrip(request *http.Request) (*http.Response, error) return v.(*http.Response), err } + +// Unwrap returns the inner round tripper. +func (t *RoundTripper) Unwrap() http.RoundTripper { + return t.tripper +} diff --git a/api/observability/tracing/http/http.go b/api/observability/tracing/http/http.go index 8d8eeca7bfe12..28567fffddbfc 100644 --- a/api/observability/tracing/http/http.go +++ b/api/observability/tracing/http/http.go @@ -15,6 +15,7 @@ package http import ( + "net/http" nethttp "net/http" "go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp" @@ -43,39 +44,37 @@ func HandlerFormatter(operation string, r *nethttp.Request) string { // https://github.com/open-telemetry/opentelemetry-go-contrib/issues/3543. // Once the issue is resolved the wrapper may be discarded. func NewTransport(rt nethttp.RoundTripper) nethttp.RoundTripper { - return enforceCloseIdleConnections( - otelhttp.NewTransport(rt, - otelhttp.WithSpanNameFormatter(TransportFormatter), - ), rt) + return &roundTripWrapper{ + RoundTripper: otelhttp.NewTransport(rt, otelhttp.WithSpanNameFormatter(TransportFormatter)), + inner: rt, + } +} + +type closeIdler interface { + CloseIdleConnections() +} + +type roundTripWrapper struct { + nethttp.RoundTripper + inner nethttp.RoundTripper +} + +// Unwrap returns the inner round tripper. +func (w *roundTripWrapper) Unwrap() http.RoundTripper { + return w.inner } -// enforceCloseIdleConnections ensures that the returned [nethttp.RoundTripper] +// CloseIdleConnections ensures that the returned [nethttp.RoundTripper] // has a CloseIdleConnections method. Since [otelhttp.Transport] does not implement // this any [nethttp.Client.CloseIdleConnections] calls result in a noop instead // of forwarding the request onto its wrapped [nethttp.RoundTripper]. // // DELETE WHEN https://github.com/open-telemetry/opentelemetry-go-contrib/issues/3543 // has been resolved. -func enforceCloseIdleConnections(wrapper, wrapped nethttp.RoundTripper) nethttp.RoundTripper { - type closeIdler interface { - CloseIdleConnections() +func (w *roundTripWrapper) CloseIdleConnections() { + if c, ok := w.RoundTripper.(closeIdler); ok { + c.CloseIdleConnections() + } else if c, ok := w.inner.(closeIdler); ok { + c.CloseIdleConnections() } - - type unwrapper struct { - nethttp.RoundTripper - closeIdler - } - - if _, ok := wrapper.(closeIdler); ok { - return wrapper - } - - if c, ok := wrapped.(closeIdler); ok { - return &unwrapper{ - RoundTripper: wrapper, - closeIdler: c, - } - } - - return wrapper } diff --git a/e b/e index f64397a5610b2..76f403a8e21bd 160000 --- a/e +++ b/e @@ -1 +1 @@ -Subproject commit f64397a5610b2bde64f8598fc9b4ed43f5065f42 +Subproject commit 76f403a8e21bde4d6f7170b560e2acf7345b2158 diff --git a/lib/auth/auth.go b/lib/auth/auth.go index aa0a702edb281..7d44366c87af8 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -4719,6 +4719,34 @@ func (a *Server) GetLicense(ctx context.Context) (string, error) { return fmt.Sprintf("%s%s", a.license.CertPEM, a.license.KeyPEM), nil } +// GetHeadlessAuthentication returns a headless authentication from the backend by name. +// If it does not yet exist, a stub will be created to signal the login process to upsert +// login details. This method will wait for the updated headless authentication and return it. +func (a *Server) GetHeadlessAuthentication(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { + // Try to create a stub if it doesn't already exist, then wait for full login details. + if _, err := a.Services.CreateHeadlessAuthenticationStub(ctx, name); err != nil && !trace.IsAlreadyExists(err) { + return nil, trace.Wrap(err) + } + + // wait for the headless authentication to be updated with valid login details + // by the login process. If the headless authentication is already updated, + // Wait will return it immediately. + waitCtx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + defer cancel() + + headlessAuthn, err := a.headlessAuthenticationWatcher.Wait(waitCtx, name, func(ha *types.HeadlessAuthentication) (bool, error) { + return services.ValidateHeadlessAuthentication(ha) == nil, nil + }) + return headlessAuthn, trace.Wrap(err) +} + +// CompareAndSwapHeadlessAuthentication performs a compare +// and swap replacement on a headless authentication resource. +func (a *Server) CompareAndSwapHeadlessAuthentication(ctx context.Context, old, new *types.HeadlessAuthentication) (*types.HeadlessAuthentication, error) { + headlessAuthn, err := a.Services.CompareAndSwapHeadlessAuthentication(ctx, old, new) + return headlessAuthn, trace.Wrap(err) +} + // authKeepAliver is a keep aliver using auth server directly type authKeepAliver struct { sync.RWMutex diff --git a/lib/auth/auth_login_test.go b/lib/auth/auth_login_test.go index dc41aba905e67..57b214cb54257 100644 --- a/lib/auth/auth_login_test.go +++ b/lib/auth/auth_login_test.go @@ -28,6 +28,7 @@ import ( "github.com/gravitational/teleport/lib/auth/mocku2f" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/services" ) func TestServer_CreateAuthenticateChallenge_authPreference(t *testing.T) { @@ -701,6 +702,114 @@ func TestServer_Authenticate_nonPasswordlessRequiresUsername(t *testing.T) { } } +func TestServer_Authenticate_headless(t *testing.T) { + t.Parallel() + + ctx := context.Background() + headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + const timeout = time.Millisecond * 200 + + type updateHeadlessAuthnFn func(*types.HeadlessAuthentication, *types.MFADevice) + updateHeadlessAuthnInGoroutine := func(ctx context.Context, srv *TestTLSServer, mfa *types.MFADevice, update updateHeadlessAuthnFn) chan error { + errC := make(chan error) + go func() { + defer close(errC) + + headlessAuthn, err := srv.Auth().GetHeadlessAuthentication(ctx, headlessID) + if err != nil { + errC <- err + return + } + + // create a shallow copy and update for the compare and swap below. + replaceHeadlessAuthn := *headlessAuthn + update(&replaceHeadlessAuthn, mfa) + + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &replaceHeadlessAuthn) + if err != nil { + errC <- err + return + } + }() + return errC + } + + for _, tc := range []struct { + name string + update updateHeadlessAuthnFn + checkErr require.ErrorAssertionFunc + }{ + { + name: "OK approved", + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + ha.MfaDevice = mfa + }, + checkErr: require.NoError, + }, { + name: "NOK approved without MFA", + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + }, + checkErr: require.Error, + }, { + name: "NOK user mismatch", + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED + ha.MfaDevice = mfa + ha.User = "other-user" + }, + checkErr: require.Error, + }, { + name: "NOK denied", + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) { + ha.State = types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED + }, + checkErr: require.Error, + }, { + name: "NOK timeout", + update: func(ha *types.HeadlessAuthentication, mfa *types.MFADevice) {}, + checkErr: require.Error, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() + + srv := newTestTLSServer(t) + proxyClient, err := srv.NewClient(TestBuiltin(types.RoleProxy)) + require.NoError(t, err) + + // We don't mind about the specifics of the configuration, as long as we have + // a user and TOTP/WebAuthn devices. + mfa := configureForMFA(t, srv) + username := mfa.User + + t.Cleanup(func() { + srv.Auth().DeleteHeadlessAuthentication(ctx, headlessID) + }) + + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + + errC := updateHeadlessAuthnInGoroutine(ctx, srv, mfa.WebDev.MFA, tc.update) + _, err = proxyClient.AuthenticateSSHUser(ctx, AuthenticateSSHRequest{ + AuthenticateUserRequest: AuthenticateUserRequest{ + Username: username, + PublicKey: []byte(sshPubKey), + HeadlessAuthenticationID: headlessID, + ClientMetadata: &ForwardedClientMetadata{ + RemoteAddr: "0.0.0.0", + }, + }, + TTL: defaults.CallbackTimeout, + }) + tc.checkErr(t, err) + require.NoError(t, <-errC) + }) + } +} + type configureMFAResp struct { User, Password string TOTPDev, WebDev *TestDevice diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index afe4a23dd67f4..df4f41d46b267 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -23,6 +23,7 @@ import ( "strings" "time" + "github.com/gravitational/roundtrip" "github.com/gravitational/trace" "github.com/sirupsen/logrus" collectortracev1 "go.opentelemetry.io/proto/otlp/collector/trace/v1" @@ -1119,30 +1120,6 @@ func (a *ServerWithRoles) UpsertNode(ctx context.Context, s types.Server) (*type return a.authServer.UpsertNode(ctx, s) } -// DELETE IN: 5.1.0 -// -// This logic has moved to KeepAliveServer. -func (a *ServerWithRoles) KeepAliveNode(ctx context.Context, handle types.KeepAlive) error { - if !a.hasBuiltinRole(types.RoleNode) { - return trace.AccessDenied("[10] access denied") - } - clusterName, err := a.GetDomainName(ctx) - if err != nil { - return trace.Wrap(err) - } - serverName, err := ExtractHostID(a.context.User.GetName(), clusterName) - if err != nil { - return trace.AccessDenied("[10] access denied") - } - if serverName != handle.Name { - return trace.AccessDenied("[10] access denied") - } - if err := a.action(apidefaults.Namespace, types.KindNode, types.VerbUpdate); err != nil { - return trace.Wrap(err) - } - return a.authServer.KeepAliveNode(ctx, handle) -} - // KeepAliveServer updates expiry time of a server resource. func (a *ServerWithRoles) KeepAliveServer(ctx context.Context, handle types.KeepAlive) error { clusterName, err := a.GetDomainName(ctx) @@ -5732,14 +5709,86 @@ func (a *ServerWithRoles) DeleteAllUserGroups(ctx context.Context) error { // GetHeadlessAuthentication retrieves a headless authentication by id. func (a *ServerWithRoles) GetHeadlessAuthentication(ctx context.Context, id string) (*types.HeadlessAuthentication, error) { - // TODO (joerger): Add implementation - follow up PR - return nil, trace.NotImplemented("GetHeadlessAuthentication is not implemented") + // GetHeadlessAuthentication will wait for the headless details + // if they don't yet exist in the backend. + ctx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + defer cancel() + + headlessAuthn, err := a.authServer.GetHeadlessAuthentication(ctx, id) + if err != nil { + return nil, trace.Wrap(err) + } + + // Only users can get their own headless authentication requests. + if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { + // This method would usually time out above if the headless authentication + // does not exist, so we mimick this behavior here for users without access. + <-ctx.Done() + return nil, trace.Wrap(ctx.Err()) + } + + return headlessAuthn, nil } // UpdateHeadlessAuthenticationState updates a headless authentication state. func (a *ServerWithRoles) UpdateHeadlessAuthenticationState(ctx context.Context, id string, state types.HeadlessAuthenticationState, mfaResp *proto.MFAAuthenticateResponse) error { - // TODO (joerger): Add implementation - follow up PR - return trace.NotImplemented("UpdateHeadlessAuthenticationState is not implemented") + // GetHeadlessAuthentication will wait for the headless details + // if they don't yet exist in the backend. + ctx, cancel := context.WithTimeout(ctx, defaults.HTTPRequestTimeout) + defer cancel() + + headlessAuthn, err := a.authServer.GetHeadlessAuthentication(ctx, id) + if err != nil { + return trace.Wrap(err) + } + + // Only users can update their own headless authentication requests. + if !hasLocalUserRole(a.context) || headlessAuthn.User != a.context.User.GetName() { + // This method would usually time out above if the headless authentication + // does not exist, so we mimick this behavior here for users without access. + <-ctx.Done() + return trace.Wrap(ctx.Err()) + } + + if headlessAuthn.State != types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_PENDING { + return trace.AccessDenied("cannot update a headless authentication state from a non-pending state") + } + + // Shallow copy headless authn for compare and swap below. + replaceHeadlessAuthn := *headlessAuthn + replaceHeadlessAuthn.State = state + + switch state { + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: + // The user must authenticate with MFA to change the state to approved. + if mfaResp == nil { + return trace.BadParameter("expected MFA auth challenge response") + } + + // Only WebAuthn is supported in headless login flow for superior phishing prevention. + if _, ok := mfaResp.Response.(*proto.MFAAuthenticateResponse_Webauthn); !ok { + return trace.BadParameter("expected WebAuthn challenge response, but got %T", mfaResp.Response) + } + + mfaDevice, _, err := a.authServer.validateMFAAuthResponse(ctx, mfaResp, headlessAuthn.User, false /* passwordless */) + if err != nil { + return trace.Wrap(err) + } + + replaceHeadlessAuthn.MfaDevice = mfaDevice + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED: + // continue to compare and swap without MFA. + default: + return trace.AccessDenied("cannot update a headless authentication state to %v", state.String()) + } + + _, err = a.authServer.CompareAndSwapHeadlessAuthentication(ctx, headlessAuthn, &replaceHeadlessAuthn) + return trace.Wrap(err) +} + +// CloneHTTPClient creates a new HTTP client with the same configuration. +func (a *ServerWithRoles) CloneHTTPClient(params ...roundtrip.ClientParam) (*HTTPClient, error) { + return nil, trace.NotImplemented("not implemented") } // NewAdminAuthServer returns auth server authorized as admin, diff --git a/lib/auth/auth_with_roles_test.go b/lib/auth/auth_with_roles_test.go index 535afafa56bcb..3094db9a9e4c2 100644 --- a/lib/auth/auth_with_roles_test.go +++ b/lib/auth/auth_with_roles_test.go @@ -38,6 +38,7 @@ import ( "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/types/installers" + "github.com/gravitational/teleport/api/types/webauthn" "github.com/gravitational/teleport/api/types/wrappers" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/sshutils" @@ -4214,3 +4215,234 @@ func TestUnimplementedClients(t *testing.T) { require.True(t, trace.IsNotImplemented(err), err) }) } + +// getTestHeadlessAuthenticationID returns the headless authentication resource +// used across headless authentication tests. +func getTestHeadlessAuthn(t *testing.T, user string) *types.HeadlessAuthentication { + headlessID := services.NewHeadlessAuthenticationID([]byte(sshPubKey)) + headlessAuthn := &types.HeadlessAuthentication{ + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Name: headlessID, + }, + }, + User: user, + PublicKey: []byte(sshPubKey), + ClientIpAddress: "0.0.0.0", + } + headlessAuthn.SetExpiry(time.Now().Add(time.Minute)) + + err := headlessAuthn.CheckAndSetDefaults() + require.NoError(t, err) + + return headlessAuthn +} + +func TestGetHeadlessAuthentication(t *testing.T) { + ctx := context.Background() + username := "teleport-user" + otherUsername := "other-user" + + for _, tc := range []struct { + name string + headlessID string + identity TestIdentity + assertError require.ErrorAssertionFunc + expectedHeadlessAuthn *types.HeadlessAuthentication + }{ + { + name: "OK same user", + identity: TestUser(username), + assertError: require.NoError, + }, { + name: "NOK not found", + headlessID: uuid.NewString(), + identity: TestUser(username), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK different user", + identity: TestUser(otherUsername), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK admin", + identity: TestAdmin(), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() + + srv := newTestTLSServer(t) + _, _, err := CreateUserAndRole(srv.Auth(), username, nil, nil) + require.NoError(t, err) + _, _, err = CreateUserAndRole(srv.Auth(), otherUsername, nil, nil) + require.NoError(t, err) + + // create headless authn + headlessAuthn := getTestHeadlessAuthn(t, username) + stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessAuthn.GetName()) + require.NoError(t, err) + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) + require.NoError(t, err) + + client, err := srv.NewClient(tc.identity) + require.NoError(t, err) + + ctx, cancel := context.WithTimeout(ctx, time.Millisecond*100) + defer cancel() + + // default to same headlessAuthn + if tc.headlessID == "" { + tc.headlessID = headlessAuthn.GetName() + } + + retrievedHeadlessAuthn, err := client.GetHeadlessAuthentication(ctx, tc.headlessID) + tc.assertError(t, err) + if err == nil { + require.Equal(t, headlessAuthn, retrievedHeadlessAuthn) + } + }) + } +} + +func TestUpdateHeadlessAuthenticationState(t *testing.T) { + ctx := context.Background() + otherUsername := "other-user" + + for _, tc := range []struct { + name string + // defaults to the mfa identity tied to the headless authentication created + identity TestIdentity + // defaults to id of the headless authentication created + headlessID string + state types.HeadlessAuthenticationState + withMFA bool + assertError require.ErrorAssertionFunc + }{ + { + name: "OK same user denied", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + assertError: require.NoError, + }, { + name: "OK same user approved with mfa", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + withMFA: true, + assertError: require.NoError, + }, { + name: "NOK same user approved without mfa", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + withMFA: false, + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.True(t, trace.IsAccessDenied(err), "expected access denied error but got: %v", err) + }, + }, { + name: "NOK not found", + headlessID: uuid.NewString(), + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK different user denied", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + identity: TestUser(otherUsername), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK different user approved", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + identity: TestUser(otherUsername), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK admin denied", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED, + identity: TestAdmin(), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, { + name: "NOK admin approved", + state: types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED, + identity: TestAdmin(), + assertError: func(t require.TestingT, err error, i ...interface{}) { + require.Error(t, err) + require.ErrorContains(t, err, context.DeadlineExceeded.Error(), "expected context deadline error but got: %v", err) + }, + }, + } { + t.Run(tc.name, func(t *testing.T) { + tc := tc + t.Parallel() + + srv := newTestTLSServer(t) + mfa := configureForMFA(t, srv) + + _, _, err := CreateUserAndRole(srv.Auth(), otherUsername, nil, nil) + require.NoError(t, err) + + // create headless authn + headlessAuthn := getTestHeadlessAuthn(t, mfa.User) + stub, err := srv.Auth().CreateHeadlessAuthenticationStub(ctx, headlessAuthn.GetName()) + require.NoError(t, err) + _, err = srv.Auth().CompareAndSwapHeadlessAuthentication(ctx, stub, headlessAuthn) + require.NoError(t, err) + + // default to mfa user + if tc.identity.I == nil { + tc.identity = TestUser(mfa.User) + } + + client, err := srv.NewClient(tc.identity) + require.NoError(t, err) + + resp := &proto.MFAAuthenticateResponse{ + Response: &proto.MFAAuthenticateResponse_Webauthn{ + Webauthn: &webauthn.CredentialAssertionResponse{ + Type: "bad response", + }, + }, + } + if tc.withMFA { + client, err := srv.NewClient(TestUser(mfa.User)) + require.NoError(t, err) + + challenge, err := client.CreateAuthenticateChallenge(ctx, &proto.CreateAuthenticateChallengeRequest{ + Request: &proto.CreateAuthenticateChallengeRequest_ContextUser{}, + }) + require.NoError(t, err) + + resp, err = mfa.WebDev.SolveAuthn(challenge) + require.NoError(t, err) + } + + // default to same headlessAuthn + if tc.headlessID == "" { + tc.headlessID = headlessAuthn.GetName() + } + + ctx, cancel := context.WithTimeout(ctx, time.Second) + defer cancel() + + err = client.UpdateHeadlessAuthenticationState(ctx, tc.headlessID, tc.state, resp) + tc.assertError(t, err) + }) + } +} diff --git a/lib/auth/clt.go b/lib/auth/clt.go index e7c3441ce8558..3eb42be3a0bc7 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -18,39 +18,25 @@ package auth import ( "context" - "crypto/tls" - "encoding/json" - "fmt" "net" - "net/http" "net/url" - "strconv" - "strings" "time" "github.com/gravitational/roundtrip" "github.com/gravitational/trace" - "github.com/gravitational/teleport" - "github.com/gravitational/teleport/api/breaker" "github.com/gravitational/teleport/api/client" "github.com/gravitational/teleport/api/client/okta" "github.com/gravitational/teleport/api/client/proto" - "github.com/gravitational/teleport/api/constants" - apidefaults "github.com/gravitational/teleport/api/defaults" devicepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/devicetrust/v1" loginrulepb "github.com/gravitational/teleport/api/gen/proto/go/teleport/loginrule/v1" pluginspb "github.com/gravitational/teleport/api/gen/proto/go/teleport/plugins/v1" samlidppb "github.com/gravitational/teleport/api/gen/proto/go/teleport/samlidp/v1" - tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" "github.com/gravitational/teleport/api/types" apievents "github.com/gravitational/teleport/api/types/events" - "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" - "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/session" - "github.com/gravitational/teleport/lib/utils" ) const ( @@ -62,6 +48,9 @@ const ( MissingNamespaceError = "missing required parameter: namespace" ) +// APIClient is aliased here so that it can be embedded in Client. +type APIClient = client.Client + // Client is the Auth API client. It works by connecting to auth servers // via gRPC and HTTP. // @@ -103,43 +92,14 @@ func NewClient(cfg client.Config, params ...roundtrip.ClientParam) (*Client, err } // apiClient configures the tls.Config, so we clone it and reuse it for http. - tlsConfig := apiClient.Config().Clone() - httpClient, err := NewHTTPClient(cfg, tlsConfig, params...) - if err != nil { - return nil, trace.Wrap(err) - } - - return &Client{ - APIClient: apiClient, - HTTPClient: httpClient, - }, nil -} - -// APIClient is aliased here so that it can be embedded in Client. -type APIClient = client.Client - -// HTTPClient is a teleport HTTP API client. -type HTTPClient struct { - roundtrip.Client - // transport defines the methods by which the client can reach the server. - transport *http.Transport - // TLS holds the TLS config for the http client. - tls *tls.Config -} - -// NewHTTPClient creates a new HTTP client with TLS authentication and the given dialer. -func NewHTTPClient(cfg client.Config, tls *tls.Config, params ...roundtrip.ClientParam) (*HTTPClient, error) { - if err := cfg.CheckAndSetDefaults(); err != nil { - return nil, err - } - - dialer := cfg.Dialer - if dialer == nil { + httpTLS := apiClient.Config().Clone() + httpDialer := cfg.Dialer + if httpDialer == nil { if len(cfg.Addrs) == 0 { return nil, trace.BadParameter("no addresses to dial") } - contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithTLSConfig(tls)) - dialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { + contextDialer := client.NewDialer(cfg.Context, cfg.KeepAlivePeriod, cfg.DialTimeout, client.WithTLSConfig(httpTLS)) + httpDialer = client.ContextDialerFunc(func(ctx context.Context, network, _ string) (conn net.Conn, err error) { for _, addr := range cfg.Addrs { conn, err = contextDialer.DialContext(ctx, network, addr) if err == nil { @@ -150,148 +110,22 @@ func NewHTTPClient(cfg client.Config, tls *tls.Config, params ...roundtrip.Clien return nil, err }) } - - // Set the next protocol. This is needed due to the Auth Server using a - // multiplexer for protocol detection. Unless next protocol is specified - // it will attempt to upgrade to HTTP2 and at that point there is no way - // to distinguish between HTTP2/JSON or GPRC. - tls.NextProtos = []string{teleport.HTTPNextProtoTLS} - // Configure ALPN SNI direct dial TLS routing information used by ALPN SNI proxy in order to - // dial auth service without using SSH tunnels. - tls = client.ConfigureALPN(tls, cfg.ALPNSNIAuthDialClusterName) - - transport := &http.Transport{ - // notice that below roundtrip.Client is passed - // teleport.APIDomain as an address for the API server, this is - // to make sure client verifies the DNS name of the API server and - // custom DialContext overrides this DNS name to the real address. - // In addition this dialer tries multiple addresses if provided - DialContext: dialer.DialContext, - ResponseHeaderTimeout: apidefaults.DefaultIOTimeout, - TLSClientConfig: tls, - - // Increase the size of the connection pool. This substantially improves the - // performance of Teleport under load as it reduces the number of TLS - // handshakes performed. - MaxIdleConns: defaults.HTTPMaxIdleConns, - MaxIdleConnsPerHost: defaults.HTTPMaxIdleConnsPerHost, - - // Limit the total number of connections to the Auth Server. Some hosts allow a low - // number of connections per process (ulimit) to a host. This is a problem for - // enhanced session recording auditing which emits so many events to the - // Audit Log (using the Auth Client) that the connection pool often does not - // have a free connection to return, so just opens a new one. This quickly - // leads to hitting the OS limit and the client returning out of file - // descriptors error. - MaxConnsPerHost: defaults.HTTPMaxConnsPerHost, - - // IdleConnTimeout defines the maximum amount of time before idle connections - // are closed. Leaving this unset will lead to connections open forever and - // will cause memory leaks in a long running process. - IdleConnTimeout: defaults.HTTPIdleTimeout, - } - - cb, err := breaker.New(cfg.CircuitBreakerConfig) - if err != nil { - return nil, trace.Wrap(err) + httpClientCfg := &HTTPClientConfig{ + TLS: httpTLS, + Dialer: httpDialer, + ALPNSNIAuthDialClusterName: cfg.ALPNSNIAuthDialClusterName, } - - clientParams := append( - []roundtrip.ClientParam{ - roundtrip.HTTPClient(&http.Client{ - Timeout: defaults.HTTPRequestTimeout, - Transport: tracehttp.NewTransport(breaker.NewRoundTripper(cb, transport)), - }), - roundtrip.SanitizerEnabled(true), - }, - params..., - ) - - // Since the client uses a custom dialer and SNI is used for TLS handshake, the address - // used here is arbitrary as it just needs to be set to pass http request validation. - httpClient, err := roundtrip.NewClient("https://"+constants.APIDomain, CurrentVersion, clientParams...) + httpClient, err := NewHTTPClient(httpClientCfg, params...) if err != nil { return nil, trace.Wrap(err) } - return &HTTPClient{ - Client: *httpClient, - transport: transport, - tls: tls, + return &Client{ + APIClient: apiClient, + HTTPClient: httpClient, }, nil } -// Close closes the HTTP client connection to the auth server. -func (c *HTTPClient) Close() { - c.transport.CloseIdleConnections() -} - -// TLSConfig returns the HTTP client's TLS config. -func (c *HTTPClient) TLSConfig() *tls.Config { - return c.tls -} - -// GetTransport returns the HTTP client's transport. -func (c *HTTPClient) GetTransport() *http.Transport { - return c.transport -} - -// ClientTimeout sets idle and dial timeouts of the HTTP transport -// used by the client. -func ClientTimeout(timeout time.Duration) roundtrip.ClientParam { - return func(c *roundtrip.Client) error { - transport, ok := (c.HTTPClient().Transport).(*http.Transport) - if !ok { - return nil - } - transport.IdleConnTimeout = timeout - transport.ResponseHeaderTimeout = timeout - return nil - } -} - -// PostJSON is a generic method that issues http POST request to the server -func (c *Client) PostJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.PostJSON(ctx, endpoint, val)) -} - -// PutJSON is a generic method that issues http PUT request to the server -func (c *Client) PutJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.PutJSON(ctx, endpoint, val)) -} - -// PostForm is a generic method that issues http POST request to the server -func (c *Client) PostForm(ctx context.Context, endpoint string, vals url.Values, files ...roundtrip.File) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.PostForm(ctx, endpoint, vals, files...)) -} - -// Get issues http GET request to the server -func (c *Client) Get(ctx context.Context, u string, params url.Values) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.Get(ctx, u, params)) -} - -// Delete issues http Delete Request to the server -func (c *Client) Delete(ctx context.Context, u string) (*roundtrip.Response, error) { - return httplib.ConvertResponse(c.Client.Delete(ctx, u)) -} - -// ProcessKubeCSR processes CSR request against Kubernetes CA, returns -// signed certificate if successful. -func (c *Client) ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error) { - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - out, err := c.PostJSON(context.TODO(), c.Endpoint("kube", "csr"), req) - if err != nil { - return nil, trace.Wrap(err) - } - var re KubeCSRResponse - if err := json.Unmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return &re, nil -} - func (c *Client) Close() error { c.HTTPClient.Close() return c.APIClient.Close() @@ -302,28 +136,6 @@ func (c *Client) CreateCertAuthority(ctx context.Context, ca types.CertAuthority return trace.NotImplemented(notImplementedMessage) } -// RotateCertAuthority starts or restarts certificate authority rotation process. -func (c *Client) RotateCertAuthority(ctx context.Context, req RotateRequest) error { - _, err := c.PostJSON(ctx, c.Endpoint("authorities", string(req.Type), "rotate"), req) - return trace.Wrap(err) -} - -// RotateExternalCertAuthority rotates external certificate authority, -// this method is used to update only public keys and certificates of the -// the certificate authorities of trusted clusters. -func (c *Client) RotateExternalCertAuthority(ctx context.Context, ca types.CertAuthority) error { - if err := services.ValidateCertAuthority(ca); err != nil { - return trace.Wrap(err) - } - data, err := services.MarshalCertAuthority(ca) - if err != nil { - return trace.Wrap(err) - } - _, err = c.PostJSON(ctx, c.Endpoint("authorities", string(ca.GetType()), "rotate", "external"), - &rotateExternalCertAuthorityRawReq{CA: data}) - return trace.Wrap(err) -} - // UpsertCertAuthority updates or inserts new cert authority func (c *Client) UpsertCertAuthority(ctx context.Context, ca types.CertAuthority) error { if err := services.ValidateCertAuthority(ca); err != nil { @@ -337,12 +149,7 @@ func (c *Client) UpsertCertAuthority(ctx context.Context, ca types.CertAuthority // Fallback to HTTP API // DELETE IN 14.0.0 case trace.IsNotImplemented(err): - data, err := services.MarshalCertAuthority(ca) - if err != nil { - return trace.Wrap(err) - } - _, err = c.PostJSON(ctx, c.Endpoint("authorities", string(ca.GetType())), - &upsertCertAuthorityRawReq{CA: data}) + err := c.HTTPClient.UpsertCertAuthority(ctx, ca) return trace.Wrap(err) default: return trace.Wrap(err) @@ -368,27 +175,8 @@ func (c *Client) GetCertAuthorities(ctx context.Context, caType types.CertAuthTy // Fallback to HTTP API // DELETE IN 14.0.0 case trace.IsNotImplemented(err): - resp, err := c.Get(ctx, c.Endpoint("authorities", string(caType)), url.Values{ - "load_keys": []string{fmt.Sprintf("%t", loadKeys)}, - }) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(resp.Bytes(), &items); err != nil { - return nil, err - } - cas := make([]types.CertAuthority, 0, len(items)) - for _, raw := range items { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - - cas = append(cas, ca) - } - - return cas, nil + cas, err := c.HTTPClient.GetCertAuthorities(ctx, caType, loadKeys) + return cas, trace.Wrap(err) default: return nil, trace.Wrap(err) } @@ -408,13 +196,7 @@ func (c *Client) GetCertAuthority(ctx context.Context, id types.CertAuthID, load // Fallback to HTTP API // DELETE IN 14.0.0 case trace.IsNotImplemented(err): - out, err := c.Get(ctx, c.Endpoint("authorities", string(id.Type), id.DomainName), url.Values{ - "load_keys": []string{fmt.Sprintf("%t", loadSigningKeys)}, - }) - if err != nil { - return nil, trace.Wrap(err) - } - ca, err := services.UnmarshalCertAuthority(out.Bytes()) + ca, err := c.HTTPClient.GetCertAuthority(ctx, id, loadSigningKeys) return ca, trace.Wrap(err) default: return nil, trace.Wrap(err) @@ -434,7 +216,7 @@ func (c *Client) DeleteCertAuthority(ctx context.Context, id types.CertAuthID) e // Fallback to HTTP API // DELETE IN 14.0.0 case trace.IsNotImplemented(err): - _, err := c.Delete(ctx, c.Endpoint("authorities", string(id.Type), id.DomainName)) + err = c.HTTPClient.DeleteCertAuthority(ctx, id) return trace.Wrap(err) default: return trace.Wrap(err) @@ -456,181 +238,21 @@ func (c *Client) UpdateUserCARoleMap(ctx context.Context, name string, roleMap t return trace.NotImplemented(notImplementedMessage) } -// RegisterUsingToken calls the auth service API to register a new node using a registration token -// which was previously issued via CreateToken/UpsertToken. -func (c *Client) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) { - if err := req.CheckAndSetDefaults(); err != nil { - return nil, trace.Wrap(err) - } - out, err := c.PostJSON(ctx, c.Endpoint("tokens", "register"), req) - if err != nil { - return nil, trace.Wrap(err) - } - - var certs proto.Certs - if err := json.Unmarshal(out.Bytes(), &certs); err != nil { - return nil, trace.Wrap(err) - } - - return &certs, nil -} - -// DELETE IN: 5.1.0 -// -// This logic has been moved to KeepAliveServer. -// -// KeepAliveNode updates node keep alive information. -func (c *Client) KeepAliveNode(ctx context.Context, keepAlive types.KeepAlive) error { - return trace.BadParameter("not implemented, use StreamKeepAlives instead") -} - // KeepAliveServer not implemented: can only be called locally. func (c *Client) KeepAliveServer(ctx context.Context, keepAlive types.KeepAlive) error { return trace.BadParameter("not implemented, use StreamKeepAlives instead") } -// UpsertReverseTunnel is used by admins to create a new reverse tunnel -// to the remote proxy to bypass firewall restrictions -func (c *Client) UpsertReverseTunnel(tunnel types.ReverseTunnel) error { - data, err := services.MarshalReverseTunnel(tunnel) - if err != nil { - return trace.Wrap(err) - } - args := &upsertReverseTunnelRawReq{ - ReverseTunnel: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("reversetunnels"), args) - return trace.Wrap(err) -} - // GetReverseTunnel not implemented: can only be called locally. func (c *Client) GetReverseTunnel(name string, opts ...services.MarshalOption) (types.ReverseTunnel, error) { return nil, trace.NotImplemented(notImplementedMessage) } -// GetReverseTunnels returns the list of created reverse tunnels -func (c *Client) GetReverseTunnels(ctx context.Context, opts ...services.MarshalOption) ([]types.ReverseTunnel, error) { - out, err := c.Get(ctx, c.Endpoint("reversetunnels"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - tunnels := make([]types.ReverseTunnel, len(items)) - for i, raw := range items { - tunnel, err := services.UnmarshalReverseTunnel(raw) - if err != nil { - return nil, trace.Wrap(err) - } - tunnels[i] = tunnel - } - return tunnels, nil -} - -// DeleteReverseTunnel deletes reverse tunnel by domain name -func (c *Client) DeleteReverseTunnel(domainName string) error { - // this is to avoid confusing error in case if domain empty for example - // HTTP route will fail producing generic not found error - // instead we catch the error here - if strings.TrimSpace(domainName) == "" { - return trace.BadParameter("empty domain name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("reversetunnels", domainName)) - return trace.Wrap(err) -} - -// UpsertTunnelConnection upserts tunnel connection -func (c *Client) UpsertTunnelConnection(conn types.TunnelConnection) error { - data, err := services.MarshalTunnelConnection(conn) - if err != nil { - return trace.Wrap(err) - } - args := &upsertTunnelConnectionRawReq{ - TunnelConnection: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("tunnelconnections"), args) - return trace.Wrap(err) -} - -// GetTunnelConnections returns tunnel connections for a given cluster -func (c *Client) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) { - if clusterName == "" { - return nil, trace.BadParameter("missing cluster name parameter") - } - out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections", clusterName), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - conns := make([]types.TunnelConnection, len(items)) - for i, raw := range items { - conn, err := services.UnmarshalTunnelConnection(raw) - if err != nil { - return nil, trace.Wrap(err) - } - conns[i] = conn - } - return conns, nil -} - -// GetAllTunnelConnections returns all tunnel connections -func (c *Client) GetAllTunnelConnections(opts ...services.MarshalOption) ([]types.TunnelConnection, error) { - out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - conns := make([]types.TunnelConnection, len(items)) - for i, raw := range items { - conn, err := services.UnmarshalTunnelConnection(raw) - if err != nil { - return nil, trace.Wrap(err) - } - conns[i] = conn - } - return conns, nil -} - -// DeleteTunnelConnection deletes tunnel connection by name -func (c *Client) DeleteTunnelConnection(clusterName string, connName string) error { - if clusterName == "" { - return trace.BadParameter("missing parameter cluster name") - } - if connName == "" { - return trace.BadParameter("missing parameter connection name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName, connName)) - return trace.Wrap(err) -} - -// DeleteTunnelConnections deletes all tunnel connections for cluster -func (c *Client) DeleteTunnelConnections(clusterName string) error { - if clusterName == "" { - return trace.BadParameter("missing parameter cluster name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName)) - return trace.Wrap(err) -} - // DeleteAllTokens not implemented: can only be called locally. func (c *Client) DeleteAllTokens() error { return trace.NotImplemented(notImplementedMessage) } -// DeleteAllTunnelConnections deletes all tunnel connections -func (c *Client) DeleteAllTunnelConnections() error { - _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections")) - return trace.Wrap(err) -} - // AddUserLoginAttempt logs user login attempt func (c *Client) AddUserLoginAttempt(user string, attempt services.LoginAttempt, ttl time.Duration) error { panic("not implemented") @@ -641,103 +263,6 @@ func (c *Client) GetUserLoginAttempts(user string) ([]services.LoginAttempt, err panic("not implemented") } -// GetRemoteClusters returns a list of remote clusters -func (c *Client) GetRemoteClusters(opts ...services.MarshalOption) ([]types.RemoteCluster, error) { - out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - conns := make([]types.RemoteCluster, len(items)) - for i, raw := range items { - conn, err := services.UnmarshalRemoteCluster(raw) - if err != nil { - return nil, trace.Wrap(err) - } - conns[i] = conn - } - return conns, nil -} - -// GetRemoteCluster returns a remote cluster by name -func (c *Client) GetRemoteCluster(clusterName string) (types.RemoteCluster, error) { - if clusterName == "" { - return nil, trace.BadParameter("missing cluster name") - } - out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters", clusterName), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalRemoteCluster(out.Bytes()) -} - -// DeleteRemoteCluster deletes remote cluster by name -func (c *Client) DeleteRemoteCluster(ctx context.Context, clusterName string) error { - if clusterName == "" { - return trace.BadParameter("missing parameter cluster name") - } - - _, err := c.Delete(ctx, c.Endpoint("remoteclusters", clusterName)) - return trace.Wrap(err) -} - -// DeleteAllRemoteClusters deletes all remote clusters -func (c *Client) DeleteAllRemoteClusters() error { - _, err := c.Delete(context.TODO(), c.Endpoint("remoteclusters")) - return trace.Wrap(err) -} - -// CreateRemoteCluster creates remote cluster resource -func (c *Client) CreateRemoteCluster(rc types.RemoteCluster) error { - data, err := services.MarshalRemoteCluster(rc) - if err != nil { - return trace.Wrap(err) - } - args := &createRemoteClusterRawReq{ - RemoteCluster: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("remoteclusters"), args) - return trace.Wrap(err) -} - -// UpsertAuthServer is used by auth servers to report their presence -// to other auth servers in form of hearbeat expiring after ttl period. -func (c *Client) UpsertAuthServer(s types.Server) error { - data, err := services.MarshalServer(s) - if err != nil { - return trace.Wrap(err) - } - args := &upsertServerRawReq{ - Server: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("authservers"), args) - return trace.Wrap(err) -} - -// GetAuthServers returns the list of auth servers registered in the cluster. -func (c *Client) GetAuthServers() ([]types.Server, error) { - out, err := c.Get(context.TODO(), c.Endpoint("authservers"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - re := make([]types.Server, len(items)) - for i, raw := range items { - server, err := services.UnmarshalServer(raw, types.KindAuthServer) - if err != nil { - return nil, trace.Wrap(err) - } - re[i] = server - } - return re, nil -} - // DeleteAllAuthServers not implemented: can only be called locally. func (c *Client) DeleteAllAuthServers() error { return trace.NotImplemented(notImplementedMessage) @@ -748,332 +273,11 @@ func (c *Client) DeleteAuthServer(name string) error { return trace.NotImplemented(notImplementedMessage) } -// UpsertProxy is used by proxies to report their presence -// to other auth servers in form of heartbeat expiring after ttl period. -func (c *Client) UpsertProxy(s types.Server) error { - data, err := services.MarshalServer(s) - if err != nil { - return trace.Wrap(err) - } - args := &upsertServerRawReq{ - Server: data, - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("proxies"), args) - return trace.Wrap(err) -} - -// GetProxies returns the list of auth servers registered in the cluster. -func (c *Client) GetProxies() ([]types.Server, error) { - out, err := c.Get(context.TODO(), c.Endpoint("proxies"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var items []json.RawMessage - if err := json.Unmarshal(out.Bytes(), &items); err != nil { - return nil, trace.Wrap(err) - } - re := make([]types.Server, len(items)) - for i, raw := range items { - server, err := services.UnmarshalServer(raw, types.KindProxy) - if err != nil { - return nil, trace.Wrap(err) - } - re[i] = server - } - return re, nil -} - -// DeleteAllProxies deletes all proxies -func (c *Client) DeleteAllProxies() error { - _, err := c.Delete(context.TODO(), c.Endpoint("proxies")) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -// DeleteProxy deletes proxy by name -func (c *Client) DeleteProxy(name string) error { - if name == "" { - return trace.BadParameter("missing parameter name") - } - _, err := c.Delete(context.TODO(), c.Endpoint("proxies", name)) - if err != nil { - return trace.Wrap(err) - } - return nil -} - -// UpsertUser user updates user entry. -func (c *Client) UpsertUser(user types.User) error { - data, err := services.MarshalUser(user) - if err != nil { - return trace.Wrap(err) - } - _, err = c.PostJSON(context.TODO(), c.Endpoint("users"), &upsertUserRawReq{User: data}) - return trace.Wrap(err) -} - // CompareAndSwapUser not implemented: can only be called locally func (c *Client) CompareAndSwapUser(ctx context.Context, new, expected types.User) error { return trace.NotImplemented(notImplementedMessage) } -// ExtendWebSession creates a new web session for a user based on another -// valid web session -func (c *Client) ExtendWebSession(ctx context.Context, req WebSessionReq) (types.WebSession, error) { - out, err := c.PostJSON(ctx, c.Endpoint("users", req.User, "web", "sessions"), req) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// CreateWebSession creates a new web session for a user -func (c *Client) CreateWebSession(ctx context.Context, user string) (types.WebSession, error) { - out, err := c.PostJSON( - ctx, - c.Endpoint("users", user, "web", "sessions"), - WebSessionReq{User: user}, - ) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// AuthenticateWebUser authenticates web user, creates and returns web session -// in case if authentication is successful -func (c *Client) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRequest) (types.WebSession, error) { - out, err := c.PostJSON( - ctx, - c.Endpoint("users", req.Username, "web", "authenticate"), - req, - ) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH -// short lived certificates as a result -func (c *Client) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHRequest) (*SSHLoginResponse, error) { - out, err := c.PostJSON( - ctx, - c.Endpoint("users", req.Username, "ssh", "authenticate"), - req, - ) - if err != nil { - return nil, trace.Wrap(err) - } - var re SSHLoginResponse - if err := json.Unmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return &re, nil -} - -// GetWebSessionInfo checks if a web sesion is valid, returns session id in case if -// it is valid, or error otherwise. -func (c *Client) GetWebSessionInfo(ctx context.Context, user, sessionID string) (types.WebSession, error) { - out, err := c.Get( - ctx, - c.Endpoint("users", user, "web", "sessions", sessionID), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalWebSession(out.Bytes()) -} - -// DeleteWebSession deletes the web session specified with sid for the given user -func (c *Client) DeleteWebSession(ctx context.Context, user string, sid string) error { - _, err := c.Delete(ctx, c.Endpoint("users", user, "web", "sessions", sid)) - return trace.Wrap(err) -} - -// GenerateHostCert takes the public key in the Open SSH “authorized_keys“ -// plain text format, signs it using Host Certificate Authority private key and returns the -// resulting certificate. -func (c *Client) GenerateHostCert( - ctx context.Context, key []byte, hostID, nodeName string, principals []string, clusterName string, role types.SystemRole, ttl time.Duration, -) ([]byte, error) { - out, err := c.PostJSON(ctx, c.Endpoint("ca", "host", "certs"), - generateHostCertReq{ - Key: key, - HostID: hostID, - NodeName: nodeName, - Principals: principals, - ClusterName: clusterName, - Roles: types.SystemRoles{role}, - TTL: ttl, - }) - if err != nil { - return nil, trace.Wrap(err) - } - - var cert string - if err := json.Unmarshal(out.Bytes(), &cert); err != nil { - return nil, err - } - - return []byte(cert), nil -} - -// ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect -func (c *Client) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{ - Query: q, - }) - if err != nil { - return nil, trace.Wrap(err) - } - var rawResponse *OIDCAuthRawResponse - if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { - return nil, trace.Wrap(err) - } - response := OIDCAuthResponse{ - Username: rawResponse.Username, - Identity: rawResponse.Identity, - Cert: rawResponse.Cert, - Req: rawResponse.Req, - TLSCert: rawResponse.TLSCert, - } - if len(rawResponse.Session) != 0 { - session, err := services.UnmarshalWebSession(rawResponse.Session) - if err != nil { - return nil, trace.Wrap(err) - } - response.Session = session - } - response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) - for i, raw := range rawResponse.HostSigners { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - response.HostSigners[i] = ca - } - return &response, nil -} - -// ValidateSAMLResponse validates response returned by SAML identity provider -func (c *Client) ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("saml", "requests", "validate"), ValidateSAMLResponseReq{ - Response: re, - ConnectorID: connectorID, - }) - if err != nil { - return nil, trace.Wrap(err) - } - var rawResponse *SAMLAuthRawResponse - if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { - return nil, trace.Wrap(err) - } - response := SAMLAuthResponse{ - Username: rawResponse.Username, - Identity: rawResponse.Identity, - Cert: rawResponse.Cert, - Req: rawResponse.Req, - TLSCert: rawResponse.TLSCert, - } - if len(rawResponse.Session) != 0 { - session, err := services.UnmarshalWebSession(rawResponse.Session) - if err != nil { - return nil, trace.Wrap(err) - } - response.Session = session - } - response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) - for i, raw := range rawResponse.HostSigners { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - response.HostSigners[i] = ca - } - return &response, nil -} - -// ValidateGithubAuthCallback validates Github auth callback returned from redirect -func (c *Client) ValidateGithubAuthCallback(ctx context.Context, q url.Values) (*GithubAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("github", "requests", "validate"), - validateGithubAuthCallbackReq{Query: q}) - if err != nil { - return nil, trace.Wrap(err) - } - var rawResponse githubAuthRawResponse - if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { - return nil, trace.Wrap(err) - } - response := GithubAuthResponse{ - Username: rawResponse.Username, - Identity: rawResponse.Identity, - Cert: rawResponse.Cert, - Req: rawResponse.Req, - TLSCert: rawResponse.TLSCert, - } - if len(rawResponse.Session) != 0 { - session, err := services.UnmarshalWebSession( - rawResponse.Session) - if err != nil { - return nil, trace.Wrap(err) - } - response.Session = session - } - response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) - for i, raw := range rawResponse.HostSigners { - ca, err := services.UnmarshalCertAuthority(raw) - if err != nil { - return nil, trace.Wrap(err) - } - response.HostSigners[i] = ca - } - return &response, nil -} - -// GetSessionChunk allows clients to receive a byte array (chunk) from a recorded -// session stream, starting from 'offset', up to 'max' in length. The upper bound -// of 'max' is set to events.MaxChunkBytes -func (c *Client) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { - if namespace == "" { - return nil, trace.BadParameter(MissingNamespaceError) - } - response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "stream"), url.Values{ - "offset": []string{strconv.Itoa(offsetBytes)}, - "bytes": []string{strconv.Itoa(maxBytes)}, - }) - if err != nil { - log.Error(err) - return nil, trace.Wrap(err) - } - return response.Bytes(), nil -} - -// Returns events that happen during a session sorted by time -// (oldest first). -// -// afterN allows to filter by "newer than N" value where N is the cursor ID -// of previously returned bunch (good for polling for latest) -func (c *Client) GetSessionEvents(namespace string, sid session.ID, afterN int) (retval []events.EventFields, err error) { - if namespace == "" { - return nil, trace.BadParameter(MissingNamespaceError) - } - query := make(url.Values) - if afterN > 0 { - query.Set("after", strconv.Itoa(afterN)) - } - response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "events"), query) - if err != nil { - return nil, trace.Wrap(err) - } - retval = make([]events.EventFields, 0) - if err := json.Unmarshal(response.Bytes(), &retval); err != nil { - return nil, trace.Wrap(err) - } - return retval, nil -} - // StreamSessionEvents streams all events from a given session recording. An error is returned on the first // channel if one is encountered. Otherwise the event channel is closed when the stream ends. // The event channel is not closed on error to prevent race conditions in downstream select statements. @@ -1101,120 +305,16 @@ func (c *Client) SearchSessionEvents(fromUTC, toUTC time.Time, limit int, order return events, lastKey, nil } -// GetNamespaces returns a list of namespaces -func (c *Client) GetNamespaces() ([]types.Namespace, error) { - out, err := c.Get(context.TODO(), c.Endpoint("namespaces"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - var re []types.Namespace - if err := utils.FastUnmarshal(out.Bytes(), &re); err != nil { - return nil, trace.Wrap(err) - } - return re, nil -} - -// GetNamespace returns namespace by name -func (c *Client) GetNamespace(name string) (*types.Namespace, error) { - if name == "" { - return nil, trace.BadParameter("missing namespace name") - } - out, err := c.Get(context.TODO(), c.Endpoint("namespaces", name), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - return services.UnmarshalNamespace(out.Bytes()) -} - -// UpsertNamespace upserts namespace -func (c *Client) UpsertNamespace(ns types.Namespace) error { - _, err := c.PostJSON(context.TODO(), c.Endpoint("namespaces"), upsertNamespaceReq{Namespace: ns}) - return trace.Wrap(err) -} - -// DeleteNamespace deletes namespace by name -func (c *Client) DeleteNamespace(name string) error { - _, err := c.Delete(context.TODO(), c.Endpoint("namespaces", name)) - return trace.Wrap(err) -} - // CreateRole not implemented: can only be called locally. func (c *Client) CreateRole(ctx context.Context, role types.Role) error { return trace.NotImplemented(notImplementedMessage) } -// GetClusterName returns a cluster name -func (c *Client) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { - out, err := c.Get(context.TODO(), c.Endpoint("configuration", "name"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - - cn, err := services.UnmarshalClusterName(out.Bytes()) - if err != nil { - return nil, trace.Wrap(err) - } - - return cn, err -} - -// SetClusterName sets cluster name once, will -// return Already Exists error if the name is already set -func (c *Client) SetClusterName(cn types.ClusterName) error { - data, err := services.MarshalClusterName(cn) - if err != nil { - return trace.Wrap(err) - } - - _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "name"), &setClusterNameReq{ClusterName: data}) - if err != nil { - return trace.Wrap(err) - } - - return nil -} - // UpsertClusterName not implemented: can only be called locally. func (c *Client) UpsertClusterName(cn types.ClusterName) error { return trace.NotImplemented(notImplementedMessage) } -// DeleteStaticTokens deletes static tokens -func (c *Client) DeleteStaticTokens() error { - _, err := c.Delete(context.TODO(), c.Endpoint("configuration", "static_tokens")) - return trace.Wrap(err) -} - -// GetStaticTokens returns a list of static register tokens -func (c *Client) GetStaticTokens() (types.StaticTokens, error) { - out, err := c.Get(context.TODO(), c.Endpoint("configuration", "static_tokens"), url.Values{}) - if err != nil { - return nil, trace.Wrap(err) - } - - st, err := services.UnmarshalStaticTokens(out.Bytes()) - if err != nil { - return nil, trace.Wrap(err) - } - - return st, err -} - -// SetStaticTokens sets a list of static register tokens -func (c *Client) SetStaticTokens(st types.StaticTokens) error { - data, err := services.MarshalStaticTokens(st) - if err != nil { - return trace.Wrap(err) - } - - _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "static_tokens"), &setStaticTokensReq{StaticTokens: data}) - if err != nil { - return trace.Wrap(err) - } - - return nil -} - // DeleteClusterName not implemented: can only be called locally. func (c *Client) DeleteClusterName() error { return trace.NotImplemented(notImplementedMessage) @@ -1255,31 +355,6 @@ func (c *Client) DeleteAllUsers() error { return trace.NotImplemented(notImplementedMessage) } -func (c *Client) ValidateTrustedCluster(ctx context.Context, validateRequest *ValidateTrustedClusterRequest) (*ValidateTrustedClusterResponse, error) { - validateRequestRaw, err := validateRequest.ToRaw() - if err != nil { - return nil, trace.Wrap(err) - } - - out, err := c.PostJSON(ctx, c.Endpoint("trustedclusters", "validate"), validateRequestRaw) - if err != nil { - return nil, trace.Wrap(err) - } - - var validateResponseRaw ValidateTrustedClusterResponseRaw - err = json.Unmarshal(out.Bytes(), &validateResponseRaw) - if err != nil { - return nil, trace.Wrap(err) - } - - validateResponse, err := validateResponseRaw.ToNative() - if err != nil { - return nil, trace.Wrap(err) - } - - return validateResponse, nil -} - // CreateResetPasswordToken creates reset password token func (c *Client) CreateResetPasswordToken(ctx context.Context, req CreateUserTokenRequest) (types.UserToken, error) { return c.APIClient.CreateResetPasswordToken(ctx, &proto.CreateResetPasswordTokenRequest{ @@ -1289,21 +364,6 @@ func (c *Client) CreateResetPasswordToken(ctx context.Context, req CreateUserTok }) } -// CreateBot creates a bot and associated resources. -func (c *Client) CreateBot(ctx context.Context, req *proto.CreateBotRequest) (*proto.CreateBotResponse, error) { - return c.APIClient.CreateBot(ctx, req) -} - -// DeleteBot deletes a certificate renewal bot and associated resources. -func (c *Client) DeleteBot(ctx context.Context, botName string) error { - return c.APIClient.DeleteBot(ctx, botName) -} - -// GetBotUsers fetches all bot users. -func (c *Client) GetBotUsers(ctx context.Context) ([]types.User, error) { - return c.APIClient.GetBotUsers(ctx) -} - // GetDatabaseServers returns all registered database proxy servers. func (c *Client) GetDatabaseServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.DatabaseServer, error) { return c.APIClient.GetDatabaseServers(ctx, namespace) @@ -1771,4 +831,7 @@ type ClientI interface { // still get an Okta client when calling this method, but all RPCs will return // "not implemented" errors (as per the default gRPC behavior). OktaClient() *okta.Client + + // CloneHTTPClient creates a new HTTP client with the same configuration. + CloneHTTPClient(params ...roundtrip.ClientParam) (*HTTPClient, error) } diff --git a/lib/auth/http_client.go b/lib/auth/http_client.go new file mode 100644 index 0000000000000..5fca47b82dfa0 --- /dev/null +++ b/lib/auth/http_client.go @@ -0,0 +1,1116 @@ +/* +Copyright 2023 Gravitational, Inc. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +*/ + +package auth + +import ( + "context" + "crypto/tls" + "encoding/json" + "fmt" + "net/http" + "net/url" + "strconv" + "strings" + "time" + + "github.com/gravitational/roundtrip" + "github.com/gravitational/trace" + "github.com/jonboulle/clockwork" + + "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/breaker" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/constants" + apidefaults "github.com/gravitational/teleport/api/defaults" + tracehttp "github.com/gravitational/teleport/api/observability/tracing/http" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/defaults" + "github.com/gravitational/teleport/lib/events" + "github.com/gravitational/teleport/lib/httplib" + "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/session" + "github.com/gravitational/teleport/lib/utils" +) + +// HTTPClientConfig contains configuration for an HTTP client. +type HTTPClientConfig struct { + // TLS holds the TLS config for the http client. + TLS *tls.Config + // MaxIdleConns controls the maximum number of idle (keep-alive) connections across all hosts. + MaxIdleConns int + // MaxIdleConnsPerHost, if non-zero, controls the maximum idle (keep-alive) connections to keep per-host. + MaxIdleConnsPerHost int + // MaxConnsPerHost limits the total number of connections per host, including connections in the dialing, + // active, and idle states. On limit violation, dials will block. + MaxConnsPerHost int + // RequestTimeout specifies a time limit for requests made by this Client. + RequestTimeout time.Duration + // IdleConnTimeout defines the maximum amount of time before idle connections are closed. + IdleConnTimeout time.Duration + // ResponseHeaderTimeout specifies the amount of time to wait for a server's + // response headers after fully writing the request (including its body, if any). + // This time does not include the time to read the response body. + ResponseHeaderTimeout time.Duration + // Dialer is a custom dialer used to dial a server. The Dialer should + // have custom logic to provide an address to the dialer. If set, Dialer + // takes precedence over all other connection options. + Dialer client.ContextDialer + // ALPNSNIAuthDialClusterName if present the client will include ALPN SNI routing information in TLS Hello message + // allowing to dial auth service through Teleport Proxy directly without using SSH Tunnels. + ALPNSNIAuthDialClusterName string + // CircuitBreakerConfig defines how the circuit breaker should behave. + CircuitBreakerConfig breaker.Config +} + +// CheckAndSetDefaults validates and sets defaults for HTTP configuration. +func (c *HTTPClientConfig) CheckAndSetDefaults() error { + if c.TLS == nil { + return trace.BadParameter("missing TLS config") + } + + if c.Dialer == nil { + return trace.BadParameter("missing dialer") + } + + // Set the next protocol. This is needed due to the Auth Server using a + // multiplexer for protocol detection. Unless next protocol is specified + // it will attempt to upgrade to HTTP2 and at that point there is no way + // to distinguish between HTTP2/JSON or GPRC. + c.TLS.NextProtos = []string{teleport.HTTPNextProtoTLS} + + // Configure ALPN SNI direct dial TLS routing information used by ALPN SNI proxy in order to + // dial auth service without using SSH tunnels. + c.TLS = client.ConfigureALPN(c.TLS, c.ALPNSNIAuthDialClusterName) + + if c.CircuitBreakerConfig.Trip == nil || c.CircuitBreakerConfig.IsSuccessful == nil { + c.CircuitBreakerConfig = breaker.DefaultBreakerConfig(clockwork.NewRealClock()) + } + + // One or both of these timeouts should be set to ensure there is a timeout in place. + if c.RequestTimeout == 0 && c.ResponseHeaderTimeout == 0 { + c.RequestTimeout = defaults.HTTPRequestTimeout + c.ResponseHeaderTimeout = apidefaults.DefaultIOTimeout + } + + // Leaving this unset will lead to connections open forever and will cause memory leaks in a long running process. + if c.IdleConnTimeout == 0 { + c.IdleConnTimeout = defaults.HTTPIdleTimeout + } + + // Increase the size of the connection pool. This substantially improves the + // performance of Teleport under load as it reduces the number of TLS + // handshakes performed. + if c.MaxIdleConns == 0 { + c.MaxIdleConns = defaults.HTTPMaxIdleConns + } + if c.MaxIdleConnsPerHost == 0 { + c.MaxIdleConnsPerHost = defaults.HTTPMaxIdleConnsPerHost + } + + // Limit the total number of connections to the Auth Server. Some hosts allow a low + // number of connections per process (ulimit) to a host. This is a problem for + // enhanced session recording auditing which emits so many events to the + // Audit Log (using the Auth Client) that the connection pool often does not + // have a free connection to return, so just opens a new one. This quickly + // leads to hitting the OS limit and the client returning out of file + // descriptors error. + if c.MaxConnsPerHost == 0 { + c.MaxConnsPerHost = defaults.HTTPMaxConnsPerHost + } + + return nil +} + +// Clone creates a new client with the same configuration. +func (c *HTTPClientConfig) Clone() *HTTPClientConfig { + return &HTTPClientConfig{ + TLS: c.TLS.Clone(), + MaxIdleConns: c.MaxIdleConns, + MaxIdleConnsPerHost: c.MaxIdleConnsPerHost, + MaxConnsPerHost: c.MaxConnsPerHost, + RequestTimeout: c.RequestTimeout, + IdleConnTimeout: c.IdleConnTimeout, + ResponseHeaderTimeout: c.ResponseHeaderTimeout, + Dialer: c.Dialer, + ALPNSNIAuthDialClusterName: c.ALPNSNIAuthDialClusterName, + CircuitBreakerConfig: c.CircuitBreakerConfig, + } +} + +// HTTPClient is a teleport HTTP API client. +type HTTPClient struct { + *roundtrip.Client + // cfg is the http client configuration. + cfg *HTTPClientConfig +} + +// NewHTTPClient creates a new HTTP client with TLS authentication and the given dialer. +func NewHTTPClient(cfg *HTTPClientConfig, params ...roundtrip.ClientParam) (*HTTPClient, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, err + } + + transport := &http.Transport{ + DialContext: cfg.Dialer.DialContext, + ResponseHeaderTimeout: cfg.ResponseHeaderTimeout, + TLSClientConfig: cfg.TLS, + MaxIdleConns: cfg.MaxIdleConns, + MaxIdleConnsPerHost: cfg.MaxIdleConnsPerHost, + MaxConnsPerHost: cfg.MaxConnsPerHost, + IdleConnTimeout: cfg.IdleConnTimeout, + } + + roundtripClient, err := newRoundtripClient(cfg, transport) + if err != nil { + return nil, trace.Wrap(err) + } + + return &HTTPClient{ + cfg: cfg, + Client: roundtripClient, + }, nil +} + +func newRoundtripClient(cfg *HTTPClientConfig, transport *http.Transport, params ...roundtrip.ClientParam) (*roundtrip.Client, error) { + cb, err := breaker.New(cfg.CircuitBreakerConfig) + if err != nil { + return nil, trace.Wrap(err) + } + + clientParams := append( + []roundtrip.ClientParam{ + roundtrip.HTTPClient(&http.Client{ + Timeout: cfg.RequestTimeout, + Transport: tracehttp.NewTransport(breaker.NewRoundTripper(cb, transport)), + }), + roundtrip.SanitizerEnabled(true), + }, + params..., + ) + + // Since the client uses a custom dialer and SNI is used for TLS handshake, the address + // used here is arbitrary as it just needs to be set to pass http request validation. + roundtripClient, err := roundtrip.NewClient("https://"+constants.APIDomain, CurrentVersion, clientParams...) + if err != nil { + return nil, trace.Wrap(err) + } + + return roundtripClient, nil +} + +// CloneHTTPClient creates a new HTTP client with the same configuration. +func (c *HTTPClient) CloneHTTPClient(params ...roundtrip.ClientParam) (*HTTPClient, error) { + cfg := c.cfg.Clone() + + // We copy the transport which may have had roundtrip.ClientParams applied on initial creation. + transport, err := c.getTransport() + if err != nil { + return nil, trace.Wrap(err) + } + + roundtripClient, err := newRoundtripClient(c.cfg, transport, params...) + if err != nil { + return nil, trace.Wrap(err) + } + + return &HTTPClient{ + Client: roundtripClient, + cfg: cfg, + }, nil +} + +// ClientParamRequestTimeout sets request timeout of the HTTP transport used by the client. +func ClientParamTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + c.HTTPClient().Timeout = timeout + return nil + } +} + +// ClientParamResponseHeaderTimeout sets response header timeout of the HTTP transport used by the client. +func ClientParamResponseHeaderTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + if t, err := getHTTPTransport(c); err == nil { + t.ResponseHeaderTimeout = timeout + } + return nil + } +} + +// ClientParamIdleConnTimeout sets idle connection header timeout of the HTTP transport used by the client. +func ClientParamIdleConnTimeout(timeout time.Duration) roundtrip.ClientParam { + return func(c *roundtrip.Client) error { + if t, err := getHTTPTransport(c); err == nil { + t.IdleConnTimeout = timeout + } + return nil + } +} + +// Close closes the HTTP client connection to the auth server. +func (c *HTTPClient) Close() { + c.Client.HTTPClient().CloseIdleConnections() +} + +// TLSConfig returns the HTTP client's TLS config. +func (c *HTTPClient) TLSConfig() *tls.Config { + return c.cfg.TLS +} + +// GetTransport returns the HTTP client's transport. +func (c *HTTPClient) getTransport() (*http.Transport, error) { + return getHTTPTransport(c.Client) +} + +func getHTTPTransport(c *roundtrip.Client) (*http.Transport, error) { + type wrapper interface { + Unwrap() http.RoundTripper + } + + transport := c.HTTPClient().Transport + for { + switch t := transport.(type) { + case wrapper: + transport = t.Unwrap() + case *http.Transport: + return t, nil + default: + return nil, trace.BadParameter("unexpected transport type %T", t) + } + } +} + +// PostJSON is a generic method that issues http POST request to the server +func (c *HTTPClient) PostJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.PostJSON(ctx, endpoint, val)) +} + +// PutJSON is a generic method that issues http PUT request to the server +func (c *HTTPClient) PutJSON(ctx context.Context, endpoint string, val interface{}) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.PutJSON(ctx, endpoint, val)) +} + +// PostForm is a generic method that issues http POST request to the server +func (c *HTTPClient) PostForm(ctx context.Context, endpoint string, vals url.Values, files ...roundtrip.File) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.PostForm(ctx, endpoint, vals, files...)) +} + +// Get issues http GET request to the server +func (c *HTTPClient) Get(ctx context.Context, u string, params url.Values) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.Get(ctx, u, params)) +} + +// Delete issues http Delete Request to the server +func (c *HTTPClient) Delete(ctx context.Context, u string) (*roundtrip.Response, error) { + return httplib.ConvertResponse(c.Client.Delete(ctx, u)) +} + +// ProcessKubeCSR processes CSR request against Kubernetes CA, returns +// signed certificate if successful. +func (c *HTTPClient) ProcessKubeCSR(req KubeCSR) (*KubeCSRResponse, error) { + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + out, err := c.PostJSON(context.TODO(), c.Endpoint("kube", "csr"), req) + if err != nil { + return nil, trace.Wrap(err) + } + var re KubeCSRResponse + if err := json.Unmarshal(out.Bytes(), &re); err != nil { + return nil, trace.Wrap(err) + } + return &re, nil +} + +// RotateCertAuthority starts or restarts certificate authority rotation process. +func (c *HTTPClient) RotateCertAuthority(ctx context.Context, req RotateRequest) error { + _, err := c.PostJSON(ctx, c.Endpoint("authorities", string(req.Type), "rotate"), req) + return trace.Wrap(err) +} + +// RotateExternalCertAuthority rotates external certificate authority, +// this method is used to update only public keys and certificates of the +// the certificate authorities of trusted clusters. +func (c *HTTPClient) RotateExternalCertAuthority(ctx context.Context, ca types.CertAuthority) error { + if err := services.ValidateCertAuthority(ca); err != nil { + return trace.Wrap(err) + } + data, err := services.MarshalCertAuthority(ca) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(ctx, c.Endpoint("authorities", string(ca.GetType()), "rotate", "external"), + &rotateExternalCertAuthorityRawReq{CA: data}) + return trace.Wrap(err) +} + +// UpsertCertAuthority updates or inserts new cert authority +// DELETE IN 14.0.0 +func (c *HTTPClient) UpsertCertAuthority(ctx context.Context, ca types.CertAuthority) error { + data, err := services.MarshalCertAuthority(ca) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(ctx, c.Endpoint("authorities", string(ca.GetType())), + &upsertCertAuthorityRawReq{CA: data}) + return trace.Wrap(err) +} + +// GetCertAuthorities returns a list of certificate authorities +// DELETE IN 14.0.0 +func (c *HTTPClient) GetCertAuthorities(ctx context.Context, caType types.CertAuthType, loadKeys bool, opts ...services.MarshalOption) ([]types.CertAuthority, error) { + resp, err := c.Get(ctx, c.Endpoint("authorities", string(caType)), url.Values{ + "load_keys": []string{fmt.Sprintf("%t", loadKeys)}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(resp.Bytes(), &items); err != nil { + return nil, err + } + cas := make([]types.CertAuthority, 0, len(items)) + for _, raw := range items { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + + cas = append(cas, ca) + } + + return cas, nil +} + +// GetCertAuthority returns certificate authority by given id. Parameter loadSigningKeys +// controls if signing keys are loaded +// DELETE IN 14.0.0 +func (c *HTTPClient) GetCertAuthority(ctx context.Context, id types.CertAuthID, loadSigningKeys bool) (types.CertAuthority, error) { + out, err := c.Get(ctx, c.Endpoint("authorities", string(id.Type), id.DomainName), url.Values{ + "load_keys": []string{fmt.Sprintf("%t", loadSigningKeys)}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + ca, err := services.UnmarshalCertAuthority(out.Bytes()) + return ca, trace.Wrap(err) +} + +// DeleteCertAuthority deletes cert authority by ID +// DELETE IN 14.0.0 +func (c *HTTPClient) DeleteCertAuthority(ctx context.Context, id types.CertAuthID) error { + _, err := c.Delete(ctx, c.Endpoint("authorities", string(id.Type), id.DomainName)) + return trace.Wrap(err) +} + +// RegisterUsingToken calls the auth service API to register a new node using a registration token +// which was previously issued via CreateToken/UpsertToken. +func (c *HTTPClient) RegisterUsingToken(ctx context.Context, req *types.RegisterUsingTokenRequest) (*proto.Certs, error) { + if err := req.CheckAndSetDefaults(); err != nil { + return nil, trace.Wrap(err) + } + out, err := c.PostJSON(ctx, c.Endpoint("tokens", "register"), req) + if err != nil { + return nil, trace.Wrap(err) + } + + var certs proto.Certs + if err := json.Unmarshal(out.Bytes(), &certs); err != nil { + return nil, trace.Wrap(err) + } + + return &certs, nil +} + +// UpsertReverseTunnel is used by admins to create a new reverse tunnel +// to the remote proxy to bypass firewall restrictions +func (c *HTTPClient) UpsertReverseTunnel(tunnel types.ReverseTunnel) error { + data, err := services.MarshalReverseTunnel(tunnel) + if err != nil { + return trace.Wrap(err) + } + args := &upsertReverseTunnelRawReq{ + ReverseTunnel: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("reversetunnels"), args) + return trace.Wrap(err) +} + +// GetReverseTunnels returns the list of created reverse tunnels +func (c *HTTPClient) GetReverseTunnels(ctx context.Context, opts ...services.MarshalOption) ([]types.ReverseTunnel, error) { + out, err := c.Get(ctx, c.Endpoint("reversetunnels"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + tunnels := make([]types.ReverseTunnel, len(items)) + for i, raw := range items { + tunnel, err := services.UnmarshalReverseTunnel(raw) + if err != nil { + return nil, trace.Wrap(err) + } + tunnels[i] = tunnel + } + return tunnels, nil +} + +// DeleteReverseTunnel deletes reverse tunnel by domain name +func (c *HTTPClient) DeleteReverseTunnel(domainName string) error { + // this is to avoid confusing error in case if domain empty for example + // HTTP route will fail producing generic not found error + // instead we catch the error here + if strings.TrimSpace(domainName) == "" { + return trace.BadParameter("empty domain name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("reversetunnels", domainName)) + return trace.Wrap(err) +} + +// UpsertTunnelConnection upserts tunnel connection +func (c *HTTPClient) UpsertTunnelConnection(conn types.TunnelConnection) error { + data, err := services.MarshalTunnelConnection(conn) + if err != nil { + return trace.Wrap(err) + } + args := &upsertTunnelConnectionRawReq{ + TunnelConnection: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("tunnelconnections"), args) + return trace.Wrap(err) +} + +// GetTunnelConnections returns tunnel connections for a given cluster +func (c *HTTPClient) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) { + if clusterName == "" { + return nil, trace.BadParameter("missing cluster name parameter") + } + out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections", clusterName), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]types.TunnelConnection, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalTunnelConnection(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// GetAllTunnelConnections returns all tunnel connections +func (c *HTTPClient) GetAllTunnelConnections(opts ...services.MarshalOption) ([]types.TunnelConnection, error) { + out, err := c.Get(context.TODO(), c.Endpoint("tunnelconnections"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]types.TunnelConnection, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalTunnelConnection(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// DeleteTunnelConnection deletes tunnel connection by name +func (c *HTTPClient) DeleteTunnelConnection(clusterName string, connName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + if connName == "" { + return trace.BadParameter("missing parameter connection name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName, connName)) + return trace.Wrap(err) +} + +// DeleteTunnelConnections deletes all tunnel connections for cluster +func (c *HTTPClient) DeleteTunnelConnections(clusterName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections", clusterName)) + return trace.Wrap(err) +} + +// DeleteAllTunnelConnections deletes all tunnel connections +func (c *HTTPClient) DeleteAllTunnelConnections() error { + _, err := c.Delete(context.TODO(), c.Endpoint("tunnelconnections")) + return trace.Wrap(err) +} + +// GetRemoteClusters returns a list of remote clusters +func (c *HTTPClient) GetRemoteClusters(opts ...services.MarshalOption) ([]types.RemoteCluster, error) { + out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + conns := make([]types.RemoteCluster, len(items)) + for i, raw := range items { + conn, err := services.UnmarshalRemoteCluster(raw) + if err != nil { + return nil, trace.Wrap(err) + } + conns[i] = conn + } + return conns, nil +} + +// GetRemoteCluster returns a remote cluster by name +func (c *HTTPClient) GetRemoteCluster(clusterName string) (types.RemoteCluster, error) { + if clusterName == "" { + return nil, trace.BadParameter("missing cluster name") + } + out, err := c.Get(context.TODO(), c.Endpoint("remoteclusters", clusterName), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalRemoteCluster(out.Bytes()) +} + +// DeleteRemoteCluster deletes remote cluster by name +func (c *HTTPClient) DeleteRemoteCluster(ctx context.Context, clusterName string) error { + if clusterName == "" { + return trace.BadParameter("missing parameter cluster name") + } + _, err := c.Delete(ctx, c.Endpoint("remoteclusters", clusterName)) + return trace.Wrap(err) +} + +// DeleteAllRemoteClusters deletes all remote clusters +func (c *HTTPClient) DeleteAllRemoteClusters() error { + _, err := c.Delete(context.TODO(), c.Endpoint("remoteclusters")) + return trace.Wrap(err) +} + +// CreateRemoteCluster creates remote cluster resource +func (c *HTTPClient) CreateRemoteCluster(rc types.RemoteCluster) error { + data, err := services.MarshalRemoteCluster(rc) + if err != nil { + return trace.Wrap(err) + } + args := &createRemoteClusterRawReq{ + RemoteCluster: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("remoteclusters"), args) + return trace.Wrap(err) +} + +// UpsertAuthServer is used by auth servers to report their presence +// to other auth servers in form of hearbeat expiring after ttl period. +func (c *HTTPClient) UpsertAuthServer(s types.Server) error { + data, err := services.MarshalServer(s) + if err != nil { + return trace.Wrap(err) + } + args := &upsertServerRawReq{ + Server: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("authservers"), args) + return trace.Wrap(err) +} + +// GetAuthServers returns the list of auth servers registered in the cluster. +func (c *HTTPClient) GetAuthServers() ([]types.Server, error) { + out, err := c.Get(context.TODO(), c.Endpoint("authservers"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + re := make([]types.Server, len(items)) + for i, raw := range items { + server, err := services.UnmarshalServer(raw, types.KindAuthServer) + if err != nil { + return nil, trace.Wrap(err) + } + re[i] = server + } + return re, nil +} + +// UpsertProxy is used by proxies to report their presence +// to other auth servers in form of heartbeat expiring after ttl period. +func (c *HTTPClient) UpsertProxy(s types.Server) error { + data, err := services.MarshalServer(s) + if err != nil { + return trace.Wrap(err) + } + args := &upsertServerRawReq{ + Server: data, + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("proxies"), args) + return trace.Wrap(err) +} + +// GetProxies returns the list of auth servers registered in the cluster. +func (c *HTTPClient) GetProxies() ([]types.Server, error) { + out, err := c.Get(context.TODO(), c.Endpoint("proxies"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var items []json.RawMessage + if err := json.Unmarshal(out.Bytes(), &items); err != nil { + return nil, trace.Wrap(err) + } + re := make([]types.Server, len(items)) + for i, raw := range items { + server, err := services.UnmarshalServer(raw, types.KindProxy) + if err != nil { + return nil, trace.Wrap(err) + } + re[i] = server + } + return re, nil +} + +// DeleteAllProxies deletes all proxies +func (c *HTTPClient) DeleteAllProxies() error { + _, err := c.Delete(context.TODO(), c.Endpoint("proxies")) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// DeleteProxy deletes proxy by name +func (c *HTTPClient) DeleteProxy(name string) error { + if name == "" { + return trace.BadParameter("missing parameter name") + } + _, err := c.Delete(context.TODO(), c.Endpoint("proxies", name)) + if err != nil { + return trace.Wrap(err) + } + return nil +} + +// UpsertUser user updates user entry. +func (c *HTTPClient) UpsertUser(user types.User) error { + data, err := services.MarshalUser(user) + if err != nil { + return trace.Wrap(err) + } + _, err = c.PostJSON(context.TODO(), c.Endpoint("users"), &upsertUserRawReq{User: data}) + return trace.Wrap(err) +} + +// ExtendWebSession creates a new web session for a user based on another +// valid web session +func (c *HTTPClient) ExtendWebSession(ctx context.Context, req WebSessionReq) (types.WebSession, error) { + out, err := c.PostJSON(ctx, c.Endpoint("users", req.User, "web", "sessions"), req) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// CreateWebSession creates a new web session for a user +func (c *HTTPClient) CreateWebSession(ctx context.Context, user string) (types.WebSession, error) { + out, err := c.PostJSON( + ctx, + c.Endpoint("users", user, "web", "sessions"), + WebSessionReq{User: user}, + ) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// AuthenticateWebUser authenticates web user, creates and returns web session +// in case if authentication is successful +func (c *HTTPClient) AuthenticateWebUser(ctx context.Context, req AuthenticateUserRequest) (types.WebSession, error) { + out, err := c.PostJSON( + ctx, + c.Endpoint("users", req.Username, "web", "authenticate"), + req, + ) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// AuthenticateSSHUser authenticates SSH console user, creates and returns a pair of signed TLS and SSH +// short lived certificates as a result +func (c *HTTPClient) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHRequest) (*SSHLoginResponse, error) { + out, err := c.PostJSON( + ctx, + c.Endpoint("users", req.Username, "ssh", "authenticate"), + req, + ) + if err != nil { + return nil, trace.Wrap(err) + } + var re SSHLoginResponse + if err := json.Unmarshal(out.Bytes(), &re); err != nil { + return nil, trace.Wrap(err) + } + return &re, nil +} + +// GetWebSessionInfo checks if a web sesion is valid, returns session id in case if +// it is valid, or error otherwise. +func (c *HTTPClient) GetWebSessionInfo(ctx context.Context, user, sessionID string) (types.WebSession, error) { + out, err := c.Get( + ctx, + c.Endpoint("users", user, "web", "sessions", sessionID), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalWebSession(out.Bytes()) +} + +// DeleteWebSession deletes the web session specified with sid for the given user +func (c *HTTPClient) DeleteWebSession(ctx context.Context, user string, sid string) error { + _, err := c.Delete(ctx, c.Endpoint("users", user, "web", "sessions", sid)) + return trace.Wrap(err) +} + +// GenerateHostCert takes the public key in the Open SSH “authorized_keys“ +// plain text format, signs it using Host Certificate Authority private key and returns the +// resulting certificate. +func (c *HTTPClient) GenerateHostCert( + ctx context.Context, key []byte, hostID, nodeName string, principals []string, clusterName string, role types.SystemRole, ttl time.Duration, +) ([]byte, error) { + out, err := c.PostJSON(ctx, c.Endpoint("ca", "host", "certs"), + generateHostCertReq{ + Key: key, + HostID: hostID, + NodeName: nodeName, + Principals: principals, + ClusterName: clusterName, + Roles: types.SystemRoles{role}, + TTL: ttl, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + var cert string + if err := json.Unmarshal(out.Bytes(), &cert); err != nil { + return nil, err + } + + return []byte(cert), nil +} + +// ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect +func (c *HTTPClient) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { + out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{ + Query: q, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var rawResponse *OIDCAuthRawResponse + if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { + return nil, trace.Wrap(err) + } + response := OIDCAuthResponse{ + Username: rawResponse.Username, + Identity: rawResponse.Identity, + Cert: rawResponse.Cert, + Req: rawResponse.Req, + TLSCert: rawResponse.TLSCert, + } + if len(rawResponse.Session) != 0 { + session, err := services.UnmarshalWebSession(rawResponse.Session) + if err != nil { + return nil, trace.Wrap(err) + } + response.Session = session + } + response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) + for i, raw := range rawResponse.HostSigners { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + response.HostSigners[i] = ca + } + return &response, nil +} + +// ValidateSAMLResponse validates response returned by SAML identity provider +func (c *HTTPClient) ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) { + out, err := c.PostJSON(ctx, c.Endpoint("saml", "requests", "validate"), ValidateSAMLResponseReq{ + Response: re, + ConnectorID: connectorID, + }) + if err != nil { + return nil, trace.Wrap(err) + } + var rawResponse *SAMLAuthRawResponse + if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { + return nil, trace.Wrap(err) + } + response := SAMLAuthResponse{ + Username: rawResponse.Username, + Identity: rawResponse.Identity, + Cert: rawResponse.Cert, + Req: rawResponse.Req, + TLSCert: rawResponse.TLSCert, + } + if len(rawResponse.Session) != 0 { + session, err := services.UnmarshalWebSession(rawResponse.Session) + if err != nil { + return nil, trace.Wrap(err) + } + response.Session = session + } + response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) + for i, raw := range rawResponse.HostSigners { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + response.HostSigners[i] = ca + } + return &response, nil +} + +// ValidateGithubAuthCallback validates Github auth callback returned from redirect +func (c *HTTPClient) ValidateGithubAuthCallback(ctx context.Context, q url.Values) (*GithubAuthResponse, error) { + out, err := c.PostJSON(ctx, c.Endpoint("github", "requests", "validate"), + validateGithubAuthCallbackReq{Query: q}) + if err != nil { + return nil, trace.Wrap(err) + } + var rawResponse githubAuthRawResponse + if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { + return nil, trace.Wrap(err) + } + response := GithubAuthResponse{ + Username: rawResponse.Username, + Identity: rawResponse.Identity, + Cert: rawResponse.Cert, + Req: rawResponse.Req, + TLSCert: rawResponse.TLSCert, + } + if len(rawResponse.Session) != 0 { + session, err := services.UnmarshalWebSession( + rawResponse.Session) + if err != nil { + return nil, trace.Wrap(err) + } + response.Session = session + } + response.HostSigners = make([]types.CertAuthority, len(rawResponse.HostSigners)) + for i, raw := range rawResponse.HostSigners { + ca, err := services.UnmarshalCertAuthority(raw) + if err != nil { + return nil, trace.Wrap(err) + } + response.HostSigners[i] = ca + } + return &response, nil +} + +// GetSessionChunk allows clients to receive a byte array (chunk) from a recorded +// session stream, starting from 'offset', up to 'max' in length. The upper bound +// of 'max' is set to events.MaxChunkBytes +func (c *HTTPClient) GetSessionChunk(namespace string, sid session.ID, offsetBytes, maxBytes int) ([]byte, error) { + if namespace == "" { + return nil, trace.BadParameter(MissingNamespaceError) + } + response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "stream"), url.Values{ + "offset": []string{strconv.Itoa(offsetBytes)}, + "bytes": []string{strconv.Itoa(maxBytes)}, + }) + if err != nil { + log.Error(err) + return nil, trace.Wrap(err) + } + return response.Bytes(), nil +} + +// Returns events that happen during a session sorted by time +// (oldest first). +// +// afterN allows to filter by "newer than N" value where N is the cursor ID +// of previously returned bunch (good for polling for latest) +func (c *HTTPClient) GetSessionEvents(namespace string, sid session.ID, afterN int) (retval []events.EventFields, err error) { + if namespace == "" { + return nil, trace.BadParameter(MissingNamespaceError) + } + query := make(url.Values) + if afterN > 0 { + query.Set("after", strconv.Itoa(afterN)) + } + response, err := c.Get(context.TODO(), c.Endpoint("namespaces", namespace, "sessions", string(sid), "events"), query) + if err != nil { + return nil, trace.Wrap(err) + } + retval = make([]events.EventFields, 0) + if err := json.Unmarshal(response.Bytes(), &retval); err != nil { + return nil, trace.Wrap(err) + } + return retval, nil +} + +// GetNamespaces returns a list of namespaces +func (c *HTTPClient) GetNamespaces() ([]types.Namespace, error) { + out, err := c.Get(context.TODO(), c.Endpoint("namespaces"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + var re []types.Namespace + if err := utils.FastUnmarshal(out.Bytes(), &re); err != nil { + return nil, trace.Wrap(err) + } + return re, nil +} + +// GetNamespace returns namespace by name +func (c *HTTPClient) GetNamespace(name string) (*types.Namespace, error) { + if name == "" { + return nil, trace.BadParameter("missing namespace name") + } + out, err := c.Get(context.TODO(), c.Endpoint("namespaces", name), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + return services.UnmarshalNamespace(out.Bytes()) +} + +// UpsertNamespace upserts namespace +func (c *HTTPClient) UpsertNamespace(ns types.Namespace) error { + _, err := c.PostJSON(context.TODO(), c.Endpoint("namespaces"), upsertNamespaceReq{Namespace: ns}) + return trace.Wrap(err) +} + +// DeleteNamespace deletes namespace by name +func (c *HTTPClient) DeleteNamespace(name string) error { + _, err := c.Delete(context.TODO(), c.Endpoint("namespaces", name)) + return trace.Wrap(err) +} + +// GetClusterName returns a cluster name +func (c *HTTPClient) GetClusterName(opts ...services.MarshalOption) (types.ClusterName, error) { + out, err := c.Get(context.TODO(), c.Endpoint("configuration", "name"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + + cn, err := services.UnmarshalClusterName(out.Bytes()) + if err != nil { + return nil, trace.Wrap(err) + } + + return cn, err +} + +// SetClusterName sets cluster name once, will +// return Already Exists error if the name is already set +func (c *HTTPClient) SetClusterName(cn types.ClusterName) error { + data, err := services.MarshalClusterName(cn) + if err != nil { + return trace.Wrap(err) + } + + _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "name"), &setClusterNameReq{ClusterName: data}) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +// DeleteStaticTokens deletes static tokens +func (c *HTTPClient) DeleteStaticTokens() error { + _, err := c.Delete(context.TODO(), c.Endpoint("configuration", "static_tokens")) + return trace.Wrap(err) +} + +// GetStaticTokens returns a list of static register tokens +func (c *HTTPClient) GetStaticTokens() (types.StaticTokens, error) { + out, err := c.Get(context.TODO(), c.Endpoint("configuration", "static_tokens"), url.Values{}) + if err != nil { + return nil, trace.Wrap(err) + } + + st, err := services.UnmarshalStaticTokens(out.Bytes()) + if err != nil { + return nil, trace.Wrap(err) + } + + return st, err +} + +// SetStaticTokens sets a list of static register tokens +func (c *HTTPClient) SetStaticTokens(st types.StaticTokens) error { + data, err := services.MarshalStaticTokens(st) + if err != nil { + return trace.Wrap(err) + } + + _, err = c.PostJSON(context.TODO(), c.Endpoint("configuration", "static_tokens"), &setStaticTokensReq{StaticTokens: data}) + if err != nil { + return trace.Wrap(err) + } + + return nil +} + +func (c *HTTPClient) ValidateTrustedCluster(ctx context.Context, validateRequest *ValidateTrustedClusterRequest) (*ValidateTrustedClusterResponse, error) { + validateRequestRaw, err := validateRequest.ToRaw() + if err != nil { + return nil, trace.Wrap(err) + } + + out, err := c.PostJSON(ctx, c.Endpoint("trustedclusters", "validate"), validateRequestRaw) + if err != nil { + return nil, trace.Wrap(err) + } + + var validateResponseRaw ValidateTrustedClusterResponseRaw + err = json.Unmarshal(out.Bytes(), &validateResponseRaw) + if err != nil { + return nil, trace.Wrap(err) + } + + validateResponse, err := validateResponseRaw.ToNative() + if err != nil { + return nil, trace.Wrap(err) + } + + return validateResponse, nil +} diff --git a/lib/auth/methods.go b/lib/auth/methods.go index 9cc024ca0ee79..a578ec6594890 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -17,6 +17,7 @@ limitations under the License. package auth import ( + "bytes" "context" "errors" "time" @@ -30,6 +31,7 @@ import ( apievents "github.com/gravitational/teleport/api/types/events" "github.com/gravitational/teleport/api/utils/keys" wanlib "github.com/gravitational/teleport/lib/auth/webauthn" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" @@ -67,9 +69,12 @@ type AuthenticateUserRequest struct { } // ForwardedClientMetadata can be used by the proxy web API to forward information about -// the client to the auth service for logging purposes. +// the client to the auth service. type ForwardedClientMetadata struct { - UserAgent string `json:"user_agent,omitempty"` + UserAgent string `json:"user_agent,omitempty"` + // RemoteAddr is the IP address of the end user. This IP address is derived + // either from a direct client connection, or from a PROXY protocol header + // if the connection is forwarded through a load balancer. RemoteAddr string `json:"remote_addr,omitempty"` } @@ -79,7 +84,7 @@ func (a *AuthenticateUserRequest) CheckAndSetDefaults() error { case a.Username == "" && a.Webauthn != nil: // OK, passwordless. case a.Username == "": return trace.BadParameter("missing parameter 'username'") - case a.Pass == nil && a.Webauthn == nil && a.OTP == nil && a.Session == nil: + case a.Pass == nil && a.Webauthn == nil && a.OTP == nil && a.Session == nil && a.HeadlessAuthenticationID == "": return trace.BadParameter("at least one authentication method is required") } return nil @@ -166,6 +171,9 @@ var ( // invalidUserpass2FError is the error for when either the provided username, // password, or second factor is incorrect. invalidUserPass2FError = trace.AccessDenied("invalid username, password or second factor") + // invalidHeadlessAuthenticationError is the generic error returned for failed headless + // authentication attempts. + invalidHeadlessAuthenticationError = trace.AccessDenied("invalid Headless authentication") ) // IsInvalidLocalCredentialError checks if an error resulted from an incorrect username, @@ -216,6 +224,18 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque return res.mfaDev, nil } authErr = invalidUserPass2FError + case req.HeadlessAuthenticationID != "": + // handle authentication before the user lock to prevent locking out users + // due to timed-out/canceled headless authentication attempts. + mfaDevice, err := s.authenticateHeadless(ctx, req) + if err != nil { + log.Debugf("Headless Authentication for user %q failed while waiting for approval: %v", user, err) + return nil, "", trace.Wrap(invalidHeadlessAuthenticationError) + } + authenticateFn = func() (*types.MFADevice, error) { + return mfaDevice, nil + } + authErr = invalidHeadlessAuthenticationError } if authenticateFn != nil { var dev *types.MFADevice @@ -234,8 +254,8 @@ func (s *Server) authenticateUser(ctx context.Context, req AuthenticateUserReque return nil, "", trace.Wrap(authErr) case dev == nil: log.Debugf( - "MFA authentication returned nil device (Webauthn = %v, TOTP = %v): %v.", - req.Webauthn != nil, req.OTP != nil, err) + "MFA authentication returned nil device (Webauthn = %v, TOTP = %v, Headless = %v): %v.", + req.Webauthn != nil, req.OTP != nil, req.HeadlessAuthenticationID != "", err) return nil, "", trace.Wrap(authErr) default: return dev, user, nil @@ -312,6 +332,83 @@ func (s *Server) authenticatePasswordless(ctx context.Context, req AuthenticateU return dev, user, nil } +func (s *Server) authenticateHeadless(ctx context.Context, req AuthenticateUserRequest) (mfa *types.MFADevice, err error) { + // Delete the headless authentication upon failure. + defer func() { + if err != nil { + if err := s.DeleteHeadlessAuthentication(s.CloseContext(), req.HeadlessAuthenticationID); err != nil && !trace.IsNotFound(err) { + log.Debugf("Failed to delete headless authentication: %v", err) + } + } + }() + + // this authentication requires two client callbacks to create a headless authentication + // stub and approve/deny the headless authentication, so we use a standard callback timeout. + ctx, cancel := context.WithTimeout(ctx, defaults.CallbackTimeout) + defer cancel() + + headlessAuthn := &types.HeadlessAuthentication{ + ResourceHeader: types.ResourceHeader{ + Metadata: types.Metadata{ + Name: req.HeadlessAuthenticationID, + }, + }, + User: req.Username, + PublicKey: req.PublicKey, + ClientIpAddress: req.ClientMetadata.RemoteAddr, + } + + // Headless Authentication should expire when the callback expires. + headlessAuthn.SetExpiry(s.clock.Now().Add(defaults.CallbackTimeout)) + + if err := services.ValidateHeadlessAuthentication(headlessAuthn); err != nil { + return nil, trace.Wrap(err) + } + + // Wait for a headless authenticated stub to be inserted by an authenticated + // call to GetHeadlessAuthentication. We do this to avoid immediately inserting + // backend items from an unauthenticated endpoint. + headlessAuthnStub, err := s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + // Only headless authentication stub can be inserted without the standard validation. + if services.ValidateHeadlessAuthentication(ha) == nil { + return false, trace.AlreadyExists("headless auth request already exists") + } + return true, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // Update headless authentication with login details. + if _, err := s.CompareAndSwapHeadlessAuthentication(ctx, headlessAuthnStub, headlessAuthn); err != nil { + return nil, trace.Wrap(err) + } + + // Wait for the request to be approved/denied. + headlessAuthn, err = s.headlessAuthenticationWatcher.Wait(ctx, req.HeadlessAuthenticationID, func(ha *types.HeadlessAuthentication) (bool, error) { + switch ha.State { + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_APPROVED: + if ha.MfaDevice == nil { + return false, trace.AccessDenied("expected mfa approval for headless authentication approval") + } + return true, nil + case types.HeadlessAuthenticationState_HEADLESS_AUTHENTICATION_STATE_DENIED: + return false, trace.AccessDenied("headless authentication denied") + } + return false, nil + }) + if err != nil { + return nil, trace.Wrap(err) + } + + // Verify that the headless authentication has not been tampered with. + if headlessAuthn.User != req.Username { + return nil, trace.AccessDenied("user mismatch") + } + + return headlessAuthn.MfaDevice, nil +} + // AuthenticateWebUser authenticates web user, creates and returns a web session // if authentication is successful. In case the existing session ID is used to authenticate, // returns the existing session instead of creating a new one @@ -509,7 +606,7 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq return nil, trace.BadParameter("source IP pinning is enabled but client IP is unknown") } - certs, err := s.generateUserCert(certRequest{ + certReq := certRequest{ user: user, ttl: req.TTL, publicKey: req.PublicKey, @@ -520,7 +617,22 @@ func (s *Server) AuthenticateSSHUser(ctx context.Context, req AuthenticateSSHReq kubernetesCluster: req.KubernetesCluster, loginIP: clientIP, attestationStatement: req.AttestationStatement, - }) + } + + // For headless authentication, a short-lived mfa-verified cert should be generated. + if req.HeadlessAuthenticationID != "" { + ha, err := s.GetHeadlessAuthentication(ctx, req.HeadlessAuthenticationID) + if err != nil { + return nil, trace.Wrap(err) + } + if !bytes.Equal(req.PublicKey, ha.PublicKey) { + return nil, trace.AccessDenied("headless authentication public key mismatch") + } + certReq.mfaVerified = ha.MfaDevice.Metadata.Name + certReq.ttl = time.Minute + } + + certs, err := s.generateUserCert(certReq) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/service/connect.go b/lib/service/connect.go index 0c291957b1146..09be003fa9895 100644 --- a/lib/service/connect.go +++ b/lib/service/connect.go @@ -1198,7 +1198,10 @@ func (process *TeleportProcess) newClientThroughTunnel(authServers []utils.NetAd func (process *TeleportProcess) newClientDirect(authServers []utils.NetAddr, tlsConfig *tls.Config, role types.SystemRole) (*auth.Client, error) { var cltParams []roundtrip.ClientParam if process.Config.ClientTimeout != 0 { - cltParams = []roundtrip.ClientParam{auth.ClientTimeout(process.Config.ClientTimeout)} + cltParams = []roundtrip.ClientParam{ + auth.ClientParamIdleConnTimeout(process.Config.ClientTimeout), + auth.ClientParamResponseHeaderTimeout(process.Config.ClientTimeout), + } } var dialOpts []grpc.DialOption diff --git a/lib/services/identity.go b/lib/services/identity.go index 2c1661038cd45..ceb280c0d3b75 100644 --- a/lib/services/identity.go +++ b/lib/services/identity.go @@ -257,6 +257,16 @@ type Identity interface { // GetKeyAttestationData gets a verified public key attestation response. GetKeyAttestationData(ctx context.Context, publicKey crypto.PublicKey) (*keys.AttestationData, error) + // CreateHeadlessAuthenticationStub creates a headless authentication stub. + CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) + + // CompareAndSwapHeadlessAuthentication performs a compare + // and swap replacement on a headless authentication resource. + CompareAndSwapHeadlessAuthentication(ctx context.Context, old, new *types.HeadlessAuthentication) (*types.HeadlessAuthentication, error) + + // DeleteHeadlessAuthentication deletes a headless authentication from the backend by name. + DeleteHeadlessAuthentication(ctx context.Context, name string) error + types.WebSessionsGetter types.WebTokensGetter diff --git a/lib/services/local/headlessauthn.go b/lib/services/local/headlessauthn.go index a0ad206835a34..c8b0a70bb3cf4 100644 --- a/lib/services/local/headlessauthn.go +++ b/lib/services/local/headlessauthn.go @@ -24,13 +24,14 @@ import ( "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/utils" ) // CreateHeadlessAuthenticationStub creates a headless authentication stub in the backend. func (s *IdentityService) CreateHeadlessAuthenticationStub(ctx context.Context, name string) (*types.HeadlessAuthentication, error) { - expires := s.Clock().Now().Add(time.Minute) + expires := s.Clock().Now().Add(defaults.CallbackTimeout) headlessAuthn := &types.HeadlessAuthentication{ ResourceHeader: types.ResourceHeader{ Metadata: types.Metadata{ diff --git a/lib/services/presence.go b/lib/services/presence.go index f15415d757646..72558f78ebe45 100644 --- a/lib/services/presence.go +++ b/lib/services/presence.go @@ -61,13 +61,6 @@ type Presence interface { // specified duration with second resolution if it's >= 1 second. UpsertNode(ctx context.Context, server types.Server) (*types.KeepAlive, error) - // DELETE IN: 5.1.0 - // - // This logic has been moved to KeepAliveServer. - // - // KeepAliveNode updates node TTL in the storage - KeepAliveNode(ctx context.Context, h types.KeepAlive) error - // GetAuthServers returns a list of registered servers GetAuthServers() ([]types.Server, error) diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index a9c5d79151583..b4227adf1b0a8 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -3117,6 +3117,10 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro return nil, trace.Wrap(err) } + var authenticationClient interface { + AuthenticateSSHUser(ctx context.Context, req auth.AuthenticateSSHRequest) (*auth.SSHLoginResponse, error) + } = authClient + authReq := auth.AuthenticateUserRequest{ Username: req.User, PublicKey: req.PubKey, @@ -3136,11 +3140,20 @@ func (h *Handler) createSSHCert(w http.ResponseWriter, r *http.Request, p httpro case constants.SecondFactorWebauthn: // WebAuthn only supports this endpoint for headless login. authReq.HeadlessAuthenticationID = req.HeadlessAuthenticationID + + // create a new http client with a standard callback timeout. + authenticationClient, err = authClient.CloneHTTPClient( + auth.ClientParamTimeout(defaults.CallbackTimeout), + auth.ClientParamResponseHeaderTimeout(defaults.CallbackTimeout), + ) + if err != nil { + return nil, trace.Wrap(err) + } default: return nil, trace.AccessDenied("unsupported second factor type: %q", cap.GetSecondFactor()) } - loginResp, err := authClient.AuthenticateSSHUser(r.Context(), auth.AuthenticateSSHRequest{ + loginResp, err := authenticationClient.AuthenticateSSHUser(r.Context(), auth.AuthenticateSSHRequest{ AuthenticateUserRequest: authReq, CompatibilityMode: req.Compatibility, TTL: req.TTL, diff --git a/rfd/0105-headless-authentication.md b/rfd/0105-headless-authentication.md index 380dc8bfbf29d..aadb9a757fc7d 100644 --- a/rfd/0105-headless-authentication.md +++ b/rfd/0105-headless-authentication.md @@ -153,7 +153,7 @@ A request id will be derived from the client's public key so that an attacker ca Note: We could also use the public key directly (base64 encoded), but we choose to use a UUID to shorten the URL and improve its readability. -As [explained above](#unauthenticated-headless-login-endpoint), the Auth server will write the request details to the backend under `/headless_authentication/` on demand. It will have a 1 minute TTL, by which point the user should have completed the headless authentication flow. The request will begin in the pending state. The Auth server then waits for the user to approve the authentication request using a resource watcher. +As [explained above](#unauthenticated-headless-login-endpoint), the Auth server will write the request details to the backend under `/headless_authentication/` on demand. It will have a short TTL, matching the callback timeout of the request. The request will begin in the pending state. The Auth server then waits for the user to approve the authentication request using a resource watcher. #### Local authentication @@ -335,8 +335,8 @@ type AuthenticateUserRequest struct { Session *SessionCreds `json:"session,omitempty"` // ClientMetadata includes forwarded information about a client ClientMetadata *ForwardedClientMetadata `json:"client_metadata,omitempty"` - // Headless determines whether headless authentication will be used - Headless bool `json:"headless"` + // HeadlessAuthenticationID is the ID for a headless authentication resource. + HeadlessAuthenticationID string `json:"headless_authentication_id"` } ```