diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index f81b359198b10..af75674e08fb1 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -87,97 +87,97 @@ func NewAPIServer(config *APIConfig) (http.Handler, error) { srv.Router.UseRawPath = true // Kubernetes extensions - srv.POST("/:version/kube/csr", srv.withAuth(srv.processKubeCSR)) + srv.POST("/:version/kube/csr", srv.WithAuth(srv.processKubeCSR)) - srv.POST("/:version/authorities/:type", srv.withAuth(srv.upsertCertAuthority)) - srv.POST("/:version/authorities/:type/rotate", srv.withAuth(srv.rotateCertAuthority)) - srv.POST("/:version/authorities/:type/rotate/external", srv.withAuth(srv.rotateExternalCertAuthority)) - srv.DELETE("/:version/authorities/:type/:domain", srv.withAuth(srv.deleteCertAuthority)) - srv.GET("/:version/authorities/:type/:domain", srv.withAuth(srv.getCertAuthority)) - srv.GET("/:version/authorities/:type", srv.withAuth(srv.getCertAuthorities)) + srv.POST("/:version/authorities/:type", srv.WithAuth(srv.upsertCertAuthority)) + srv.POST("/:version/authorities/:type/rotate", srv.WithAuth(srv.rotateCertAuthority)) + srv.POST("/:version/authorities/:type/rotate/external", srv.WithAuth(srv.rotateExternalCertAuthority)) + srv.DELETE("/:version/authorities/:type/:domain", srv.WithAuth(srv.deleteCertAuthority)) + srv.GET("/:version/authorities/:type/:domain", srv.WithAuth(srv.getCertAuthority)) + srv.GET("/:version/authorities/:type", srv.WithAuth(srv.getCertAuthorities)) // Generating certificates for user and host authorities - srv.POST("/:version/ca/host/certs", srv.withAuth(srv.generateHostCert)) + srv.POST("/:version/ca/host/certs", srv.WithAuth(srv.generateHostCert)) // Operations on users - srv.GET("/:version/users", srv.withAuth(srv.getUsers)) - srv.GET("/:version/users/:user", srv.withAuth(srv.getUser)) - srv.DELETE("/:version/users/:user", srv.withAuth(srv.deleteUser)) // DELETE IN: 5.2 REST method is replaced by grpc method with context. + srv.GET("/:version/users", srv.WithAuth(srv.getUsers)) + srv.GET("/:version/users/:user", srv.WithAuth(srv.getUser)) + srv.DELETE("/:version/users/:user", srv.WithAuth(srv.deleteUser)) // DELETE IN: 5.2 REST method is replaced by grpc method with context. // Passwords and sessions - srv.POST("/:version/users", srv.withAuth(srv.upsertUser)) - srv.PUT("/:version/users/:user/web/password", srv.withAuth(srv.changePassword)) - srv.POST("/:version/users/:user/web/password/check", srv.withRate(srv.withAuth(srv.checkPassword))) - srv.POST("/:version/users/:user/web/sessions", srv.withAuth(srv.createWebSession)) - srv.POST("/:version/users/:user/web/authenticate", srv.withAuth(srv.authenticateWebUser)) - srv.POST("/:version/users/:user/ssh/authenticate", srv.withAuth(srv.authenticateSSHUser)) - srv.GET("/:version/users/:user/web/sessions/:sid", srv.withAuth(srv.getWebSession)) - srv.DELETE("/:version/users/:user/web/sessions/:sid", srv.withAuth(srv.deleteWebSession)) + srv.POST("/:version/users", srv.WithAuth(srv.upsertUser)) + srv.PUT("/:version/users/:user/web/password", srv.WithAuth(srv.changePassword)) + srv.POST("/:version/users/:user/web/password/check", srv.withRate(srv.WithAuth(srv.checkPassword))) + srv.POST("/:version/users/:user/web/sessions", srv.WithAuth(srv.createWebSession)) + srv.POST("/:version/users/:user/web/authenticate", srv.WithAuth(srv.authenticateWebUser)) + srv.POST("/:version/users/:user/ssh/authenticate", srv.WithAuth(srv.authenticateSSHUser)) + srv.GET("/:version/users/:user/web/sessions/:sid", srv.WithAuth(srv.getWebSession)) + srv.DELETE("/:version/users/:user/web/sessions/:sid", srv.WithAuth(srv.deleteWebSession)) // Servers and presence heartbeat - srv.POST("/:version/namespaces/:namespace/nodes/keepalive", srv.withAuth(srv.keepAliveNode)) - srv.POST("/:version/authservers", srv.withAuth(srv.upsertAuthServer)) - srv.GET("/:version/authservers", srv.withAuth(srv.getAuthServers)) - srv.POST("/:version/proxies", srv.withAuth(srv.upsertProxy)) - srv.GET("/:version/proxies", srv.withAuth(srv.getProxies)) - srv.DELETE("/:version/proxies", srv.withAuth(srv.deleteAllProxies)) - srv.DELETE("/:version/proxies/:name", srv.withAuth(srv.deleteProxy)) - srv.POST("/:version/tunnelconnections", srv.withAuth(srv.upsertTunnelConnection)) - srv.GET("/:version/tunnelconnections/:cluster", srv.withAuth(srv.getTunnelConnections)) - srv.GET("/:version/tunnelconnections", srv.withAuth(srv.getAllTunnelConnections)) - srv.DELETE("/:version/tunnelconnections/:cluster/:conn", srv.withAuth(srv.deleteTunnelConnection)) - srv.DELETE("/:version/tunnelconnections/:cluster", srv.withAuth(srv.deleteTunnelConnections)) - srv.DELETE("/:version/tunnelconnections", srv.withAuth(srv.deleteAllTunnelConnections)) + srv.POST("/:version/namespaces/:namespace/nodes/keepalive", srv.WithAuth(srv.keepAliveNode)) + srv.POST("/:version/authservers", srv.WithAuth(srv.upsertAuthServer)) + srv.GET("/:version/authservers", srv.WithAuth(srv.getAuthServers)) + srv.POST("/:version/proxies", srv.WithAuth(srv.upsertProxy)) + srv.GET("/:version/proxies", srv.WithAuth(srv.getProxies)) + srv.DELETE("/:version/proxies", srv.WithAuth(srv.deleteAllProxies)) + srv.DELETE("/:version/proxies/:name", srv.WithAuth(srv.deleteProxy)) + srv.POST("/:version/tunnelconnections", srv.WithAuth(srv.upsertTunnelConnection)) + srv.GET("/:version/tunnelconnections/:cluster", srv.WithAuth(srv.getTunnelConnections)) + srv.GET("/:version/tunnelconnections", srv.WithAuth(srv.getAllTunnelConnections)) + srv.DELETE("/:version/tunnelconnections/:cluster/:conn", srv.WithAuth(srv.deleteTunnelConnection)) + srv.DELETE("/:version/tunnelconnections/:cluster", srv.WithAuth(srv.deleteTunnelConnections)) + srv.DELETE("/:version/tunnelconnections", srv.WithAuth(srv.deleteAllTunnelConnections)) // Remote clusters - srv.POST("/:version/remoteclusters", srv.withAuth(srv.createRemoteCluster)) - srv.GET("/:version/remoteclusters/:cluster", srv.withAuth(srv.getRemoteCluster)) - srv.GET("/:version/remoteclusters", srv.withAuth(srv.getRemoteClusters)) - srv.DELETE("/:version/remoteclusters/:cluster", srv.withAuth(srv.deleteRemoteCluster)) - srv.DELETE("/:version/remoteclusters", srv.withAuth(srv.deleteAllRemoteClusters)) + srv.POST("/:version/remoteclusters", srv.WithAuth(srv.createRemoteCluster)) + srv.GET("/:version/remoteclusters/:cluster", srv.WithAuth(srv.getRemoteCluster)) + srv.GET("/:version/remoteclusters", srv.WithAuth(srv.getRemoteClusters)) + srv.DELETE("/:version/remoteclusters/:cluster", srv.WithAuth(srv.deleteRemoteCluster)) + srv.DELETE("/:version/remoteclusters", srv.WithAuth(srv.deleteAllRemoteClusters)) // Reverse tunnels - srv.POST("/:version/reversetunnels", srv.withAuth(srv.upsertReverseTunnel)) - srv.GET("/:version/reversetunnels", srv.withAuth(srv.getReverseTunnels)) - srv.DELETE("/:version/reversetunnels/:domain", srv.withAuth(srv.deleteReverseTunnel)) + srv.POST("/:version/reversetunnels", srv.WithAuth(srv.upsertReverseTunnel)) + srv.GET("/:version/reversetunnels", srv.WithAuth(srv.getReverseTunnels)) + srv.DELETE("/:version/reversetunnels/:domain", srv.WithAuth(srv.deleteReverseTunnel)) // trusted clusters - srv.POST("/:version/trustedclusters/validate", srv.withAuth(srv.validateTrustedCluster)) + srv.POST("/:version/trustedclusters/validate", srv.WithAuth(srv.validateTrustedCluster)) // Tokens - srv.POST("/:version/tokens/register", srv.withAuth(srv.registerUsingToken)) + srv.POST("/:version/tokens/register", srv.WithAuth(srv.registerUsingToken)) // Active sessions - srv.GET("/:version/namespaces/:namespace/sessions/:id/stream", srv.withAuth(srv.getSessionChunk)) - srv.GET("/:version/namespaces/:namespace/sessions/:id/events", srv.withAuth(srv.getSessionEvents)) + srv.GET("/:version/namespaces/:namespace/sessions/:id/stream", srv.WithAuth(srv.getSessionChunk)) + srv.GET("/:version/namespaces/:namespace/sessions/:id/events", srv.WithAuth(srv.getSessionEvents)) // DELETE IN 12.0.0 - srv.POST("/:version/namespaces/:namespace/sessions", srv.withAuth(srv.createSession)) - srv.PUT("/:version/namespaces/:namespace/sessions/:id", srv.withAuth(srv.updateSession)) - srv.DELETE("/:version/namespaces/:namespace/sessions/:id", srv.withAuth(srv.deleteSession)) - srv.GET("/:version/namespaces/:namespace/sessions", srv.withAuth(srv.getSessions)) - srv.GET("/:version/namespaces/:namespace/sessions/:id", srv.withAuth(srv.getSession)) + srv.POST("/:version/namespaces/:namespace/sessions", srv.WithAuth(srv.createSession)) + srv.PUT("/:version/namespaces/:namespace/sessions/:id", srv.WithAuth(srv.updateSession)) + srv.DELETE("/:version/namespaces/:namespace/sessions/:id", srv.WithAuth(srv.deleteSession)) + srv.GET("/:version/namespaces/:namespace/sessions", srv.WithAuth(srv.getSessions)) + srv.GET("/:version/namespaces/:namespace/sessions/:id", srv.WithAuth(srv.getSession)) // Namespaces - srv.POST("/:version/namespaces", srv.withAuth(srv.upsertNamespace)) - srv.GET("/:version/namespaces", srv.withAuth(srv.getNamespaces)) - srv.GET("/:version/namespaces/:namespace", srv.withAuth(srv.getNamespace)) - srv.DELETE("/:version/namespaces/:namespace", srv.withAuth(srv.deleteNamespace)) + srv.POST("/:version/namespaces", srv.WithAuth(srv.upsertNamespace)) + srv.GET("/:version/namespaces", srv.WithAuth(srv.getNamespaces)) + srv.GET("/:version/namespaces/:namespace", srv.WithAuth(srv.getNamespace)) + srv.DELETE("/:version/namespaces/:namespace", srv.WithAuth(srv.deleteNamespace)) // cluster configuration - srv.GET("/:version/configuration/name", srv.withAuth(srv.getClusterName)) - srv.POST("/:version/configuration/name", srv.withAuth(srv.setClusterName)) - srv.GET("/:version/configuration/static_tokens", srv.withAuth(srv.getStaticTokens)) - srv.DELETE("/:version/configuration/static_tokens", srv.withAuth(srv.deleteStaticTokens)) - srv.POST("/:version/configuration/static_tokens", srv.withAuth(srv.setStaticTokens)) + srv.GET("/:version/configuration/name", srv.WithAuth(srv.getClusterName)) + srv.POST("/:version/configuration/name", srv.WithAuth(srv.setClusterName)) + srv.GET("/:version/configuration/static_tokens", srv.WithAuth(srv.getStaticTokens)) + srv.DELETE("/:version/configuration/static_tokens", srv.WithAuth(srv.deleteStaticTokens)) + srv.POST("/:version/configuration/static_tokens", srv.WithAuth(srv.setStaticTokens)) // SSO validation handlers - srv.POST("/:version/oidc/requests/validate", srv.withAuth(srv.validateOIDCAuthCallback)) - srv.POST("/:version/saml/requests/validate", srv.withAuth(srv.validateSAMLResponse)) - srv.POST("/:version/github/requests/validate", srv.withAuth(srv.validateGithubAuthCallback)) + srv.POST("/:version/oidc/requests/validate", srv.WithAuth(srv.validateOIDCAuthCallback)) + srv.POST("/:version/saml/requests/validate", srv.WithAuth(srv.validateSAMLResponse)) + srv.POST("/:version/github/requests/validate", srv.WithAuth(srv.validateGithubAuthCallback)) // Audit logs AKA events - srv.GET("/:version/events", srv.withAuth(srv.searchEvents)) - srv.GET("/:version/events/session", srv.withAuth(srv.searchSessionEvents)) + srv.GET("/:version/events", srv.WithAuth(srv.searchEvents)) + srv.GET("/:version/events/session", srv.WithAuth(srv.searchSessionEvents)) if config.PluginRegistry != nil { if err := config.PluginRegistry.RegisterAuthWebHandlers(&srv); err != nil { @@ -195,7 +195,7 @@ func NewAPIServer(config *APIConfig) (http.Handler, error) { // HandlerWithAuthFunc is http handler with passed auth context type HandlerWithAuthFunc func(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) -func (s *APIServer) withAuth(handler HandlerWithAuthFunc) httprouter.Handle { +func (s *APIServer) WithAuth(handler HandlerWithAuthFunc) httprouter.Handle { const accessDeniedMsg = "auth API: access denied " return httplib.MakeHandler(func(w http.ResponseWriter, r *http.Request, p httprouter.Params) (interface{}, error) { // HTTPS server expects auth context to be set by the auth middleware @@ -868,32 +868,8 @@ func (s *APIServer) getSession(auth ClientI, w http.ResponseWriter, r *http.Requ return se, nil } -type validateOIDCAuthCallbackReq struct { - Query url.Values `json:"query"` -} - -// oidcAuthRawResponse is returned when auth server validated callback parameters -// returned from OIDC provider -type oidcAuthRawResponse struct { - // Username is authenticated teleport username - Username string `json:"username"` - // Identity contains validated OIDC identity - Identity types.ExternalIdentity `json:"identity"` - // Web session will be generated by auth server if requested in OIDCAuthRequest - Session json.RawMessage `json:"session,omitempty"` - // Cert will be generated by certificate authority - Cert []byte `json:"cert,omitempty"` - // TLSCert is PEM encoded TLS certificate - TLSCert []byte `json:"tls_cert,omitempty"` - // Req is original oidc auth request - Req OIDCAuthRequest `json:"req"` - // HostSigners is a list of signing host public keys - // trusted by proxy, used in console login - HostSigners []json.RawMessage `json:"host_signers"` -} - func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - var req *validateOIDCAuthCallbackReq + var req *ValidateOIDCAuthCallbackReq if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } @@ -901,7 +877,7 @@ func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter if err != nil { return nil, trace.Wrap(err) } - raw := oidcAuthRawResponse{ + raw := OIDCAuthRawResponse{ Username: response.Username, Identity: response.Identity, Cert: response.Cert, @@ -926,33 +902,8 @@ func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter return &raw, nil } -type validateSAMLResponseReq struct { - Response string `json:"response"` - ConnectorID string `json:"connector_id,omitempty"` -} - -// samlAuthRawResponse is returned when auth server validated callback parameters -// returned from SAML provider -type samlAuthRawResponse struct { - // Username is authenticated teleport username - Username string `json:"username"` - // Identity contains validated OIDC identity - Identity types.ExternalIdentity `json:"identity"` - // Web session will be generated by auth server if requested in OIDCAuthRequest - Session json.RawMessage `json:"session,omitempty"` - // Cert will be generated by certificate authority - Cert []byte `json:"cert,omitempty"` - // Req is original oidc auth request - Req SAMLAuthRequest `json:"req"` - // HostSigners is a list of signing host public keys - // trusted by proxy, used in console login - HostSigners []json.RawMessage `json:"host_signers"` - // TLSCert is TLS certificate authority certificate - TLSCert []byte `json:"tls_cert,omitempty"` -} - func (s *APIServer) validateSAMLResponse(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - var req *validateSAMLResponseReq + var req *ValidateSAMLResponseReq if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } @@ -960,7 +911,7 @@ func (s *APIServer) validateSAMLResponse(auth ClientI, w http.ResponseWriter, r if err != nil { return nil, trace.Wrap(err) } - raw := samlAuthRawResponse{ + raw := SAMLAuthRawResponse{ Username: response.Username, Identity: response.Identity, Cert: response.Cert, diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 8e612a93044a3..e0e3440bc08de 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -35,20 +35,16 @@ import ( "math/big" insecurerand "math/rand" "net" - "net/url" "os" "strings" "sync" "time" - "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oauth2" - "github.com/coreos/go-oidc/oidc" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/prometheus/client_golang/prometheus" - saml2 "github.com/russellhaering/gosaml2" "github.com/sirupsen/logrus" "go.opentelemetry.io/otel/exporters/otlp/otlptrace" "golang.org/x/crypto/ssh" @@ -252,18 +248,15 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { Authority: cfg.Authority, AuthServiceName: cfg.AuthServiceName, ServerID: cfg.HostUUID, - oidcClients: make(map[string]*oidcClient), - samlProviders: make(map[string]*samlProvider), githubClients: make(map[string]*githubClient), cancelFunc: cancelFunc, closeCtx: closeCtx, emitter: cfg.Emitter, streamer: cfg.Streamer, - unstable: local.NewUnstableService(cfg.Backend, cfg.AssertionReplayService), + Unstable: local.NewUnstableService(cfg.Backend, cfg.AssertionReplayService), Services: services, Cache: services, keyStore: keyStore, - getClaimsFun: getClaims, inventory: inventory.NewController(cfg.Presence, inventory.WithAuthServerID(cfg.HostUUID)), traceClient: cfg.TraceClient, fips: cfg.FIPS, @@ -299,6 +292,22 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { ) } } + // Plug in SAML auth service + sas := NewSAMLAuthService(&SAMLAuthServiceConfig{ + Auth: &as, + Emitter: as.emitter, + AssertionReplayService: as.Unstable.AssertionReplayService, + }) + as.SetSAMLService(sas) + + oas, err := NewOIDCAuthService(&OIDCAuthServiceConfig{ + Auth: &as, + Emitter: as.emitter, + }) + if err != nil { + return nil, trace.Wrap(err) + } + as.SetOIDCService(oas) return &as, nil } @@ -413,10 +422,7 @@ var ( // - same for users and their sessions // - checks public keys to see if they're signed by it (can be trusted or not) type Server struct { - lock sync.RWMutex - // oidcClients is a map from authID & proxyAddr -> oidcClient - oidcClients map[string]*oidcClient - samlProviders map[string]*samlProvider + lock sync.RWMutex githubClients map[string]*githubClient clock clockwork.Clock bk backend.Backend @@ -424,6 +430,9 @@ type Server struct { closeCtx context.Context cancelFunc context.CancelFunc + samlAuthService SAMLService + oidcAuthService OIDCService + sshca.Authority // AuthServiceName is a human-readable name of this CA. If several Auth services are running @@ -434,9 +443,9 @@ type Server struct { // ServerID is the server ID of this auth server. ServerID string - // unstable implements unstable backend methods not suitable + // Unstable implements Unstable backend methods not suitable // for inclusion in Services. - unstable local.UnstableService + Unstable local.UnstableService // Services encapsulate services - provisioner, trust, etc. used by the auth // server in a separate structure. Reads through Services hit the backend. @@ -473,9 +482,6 @@ type Server struct { // lockWatcher is a lock watcher, used to verify cert generation requests. lockWatcher *services.LockWatcher - // getClaimsFun is used in tests for overriding the implementation of getClaims method used in OIDC. - getClaimsFun func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) - inventory *inventory.Controller // githubOrgSSOCache is used to cache whether Github organizations use @@ -501,6 +507,20 @@ type Server struct { loadAllCAs bool } +// SetSAMLService registers svc as the SAMLService that provides the SAML +// connector implementation. If a SAMLService has already been registered, this +// will override the previous registration. +func (a *Server) SetSAMLService(svc SAMLService) { + a.samlAuthService = svc +} + +// SetOIDCService registers svc as the OIDCService that provides the OIDC +// connector implementation. If a OIDCService has already been registered, this +// will override the previous registration. +func (a *Server) SetOIDCService(svc OIDCService) { + a.oidcAuthService = svc +} + func (a *Server) CloseContext() context.Context { return a.closeCtx } @@ -2303,19 +2323,13 @@ func (a *Server) CreateWebSession(ctx context.Context, user string) (types.WebSe if err != nil { return nil, trace.Wrap(err) } - sess, err := a.NewWebSession(ctx, types.NewWebSessionRequest{ + session, err := a.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ User: user, Roles: u.GetRoles(), Traits: u.GetTraits(), LoginTime: a.clock.Now().UTC(), }) - if err != nil { - return nil, trace.Wrap(err) - } - if err := a.upsertWebSession(ctx, user, sess); err != nil { - return nil, trace.Wrap(err) - } - return sess, nil + return session, trace.Wrap(err) } // GenerateToken generates multi-purpose authentication token. @@ -2564,11 +2578,11 @@ func (a *Server) GenerateHostCerts(ctx context.Context, req *proto.HostCertsRequ // instances to prove that they hold a given system role. // DELETE IN: 12.0 (deprecated in v11, but required for back-compat with v10 clients) func (a *Server) UnstableAssertSystemRole(ctx context.Context, req proto.UnstableSystemRoleAssertion) error { - return trace.Wrap(a.unstable.AssertSystemRole(ctx, req)) + return trace.Wrap(a.Unstable.AssertSystemRole(ctx, req)) } func (a *Server) UnstableGetSystemRoleAssertions(ctx context.Context, serverID string, assertionID string) (proto.UnstableSystemRoleAssertionSet, error) { - set, err := a.unstable.GetSystemRoleAssertions(ctx, serverID, assertionID) + set, err := a.Unstable.GetSystemRoleAssertions(ctx, serverID, assertionID) return set, trace.Wrap(err) } @@ -4108,23 +4122,6 @@ const ( SessionTokenBytes = 32 ) -// oidcClient is internal structure that stores OIDC client and its config -type oidcClient struct { - client *oidc.Client - connector types.OIDCConnector - // syncCtx controls the provider sync goroutine. - syncCtx context.Context - syncCancel context.CancelFunc - // firstSync will be closed once the first provider sync succeeds - firstSync chan struct{} -} - -// samlProvider is internal structure that stores SAML client and its config -type samlProvider struct { - provider *saml2.SAMLServiceProvider - connector types.SAMLConnector -} - // githubClient is internal structure that stores Github OAuth 2client and its config type githubClient struct { client *oauth2.Client @@ -4162,19 +4159,6 @@ func oauth2ConfigsEqual(a, b oauth2.Config) bool { return true } -// isHTTPS checks if the scheme for a URL is https or not. -func isHTTPS(u string) error { - earl, err := url.Parse(u) - if err != nil { - return trace.Wrap(err) - } - if earl.Scheme != "https" { - return trace.BadParameter("expected scheme https, got %q", earl.Scheme) - } - - return nil -} - // WithClusterCAs returns a TLS hello callback that returns a copy of the provided // TLS config with client CAs pool of the specified cluster. func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) { diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 41deff5552065..38aa1a3b172dc 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -21,7 +21,6 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509/pkix" - "encoding/json" "errors" "fmt" mathrand "math/rand" @@ -31,7 +30,6 @@ import ( "testing" "time" - "github.com/coreos/go-oidc/jose" "github.com/google/go-cmp/cmp" "github.com/google/uuid" reporting "github.com/gravitational/reporting/types" @@ -91,10 +89,13 @@ func newTestPack( if err != nil { return p, trace.Wrap(err) } + + p.mockEmitter = &eventstest.MockEmitter{} authConfig := &InitConfig{ Backend: p.bk, ClusterName: p.clusterName, Authority: testauthority.New(), + Emitter: p.mockEmitter, SkipPeriodicOperations: true, KeyStoreConfig: keystore.Config{ Software: keystore.SoftwareConfig{ @@ -168,8 +169,6 @@ func newTestPack( return p, trace.Wrap(err) } - p.mockEmitter = &eventstest.MockEmitter{} - p.a.emitter = p.mockEmitter return p, nil } @@ -757,107 +756,6 @@ func TestGenerateTokenEventsEmitted(t *testing.T) { require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.TrustedClusterTokenCreateEvent) } -func TestValidateACRValues(t *testing.T) { - s := newAuthSuite(t) - - tests := []struct { - comment string - inIDToken string - inACRValue string - inACRProvider string - outIsValid require.ErrorAssertionFunc - }{ - { - "0 - default, acr values match", - ` -{ - "acr": "foo", - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo", - "", - require.NoError, - }, - { - "1 - default, acr values do not match", - ` -{ - "acr": "foo", - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "bar", - "", - require.Error, - }, - { - "2 - netiq, acr values match", - ` -{ - "acr": { - "values": [ - "foo/bar/baz" - ] - }, - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo/bar/baz", - "netiq", - require.NoError, - }, - { - "3 - netiq, invalid format", - ` -{ - "acr": { - "values": "foo/bar/baz" - }, - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo/bar/baz", - "netiq", - require.Error, - }, - { - "4 - netiq, invalid value", - ` -{ - "acr": { - "values": [ - "foo/bar/baz/qux" - ] - }, - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo/bar/baz", - "netiq", - require.Error, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.comment, func(t *testing.T) { - t.Parallel() - var claims jose.Claims - err := json.Unmarshal([]byte(tt.inIDToken), &claims) - require.NoError(t, err) - - err = s.a.validateACRValues(tt.inACRValue, tt.inACRProvider, claims) - tt.outIsValid(t, err) - }) - } -} - func TestUpdateConfig(t *testing.T) { t.Parallel() s := newAuthSuite(t) @@ -1038,18 +936,21 @@ func TestGithubConnectorCRUDEventsEmitted(t *testing.T) { require.NoError(t, err) err = s.a.upsertGithubConnector(ctx, github) require.NoError(t, err) + require.IsType(t, &apievents.GithubConnectorCreate{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.GithubConnectorCreatedEvent) s.mockEmitter.Reset() // test github update event err = s.a.upsertGithubConnector(ctx, github) require.NoError(t, err) + require.IsType(t, &apievents.GithubConnectorCreate{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.GithubConnectorCreatedEvent) s.mockEmitter.Reset() // test github delete event err = s.a.deleteGithubConnector(ctx, "test") require.NoError(t, err) + require.IsType(t, &apievents.GithubConnectorDelete{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.GithubConnectorDeletedEvent) } @@ -1073,18 +974,21 @@ func TestOIDCConnectorCRUDEventsEmitted(t *testing.T) { require.NoError(t, err) err = s.a.UpsertOIDCConnector(ctx, oidc) require.NoError(t, err) + require.IsType(t, &apievents.OIDCConnectorCreate{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.OIDCConnectorCreatedEvent) s.mockEmitter.Reset() // test oidc update event err = s.a.UpsertOIDCConnector(ctx, oidc) require.NoError(t, err) + require.IsType(t, &apievents.OIDCConnectorCreate{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.OIDCConnectorCreatedEvent) s.mockEmitter.Reset() // test oidc delete event err = s.a.DeleteOIDCConnector(ctx, "test") require.NoError(t, err) + require.IsType(t, &apievents.OIDCConnectorDelete{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.OIDCConnectorDeletedEvent) } @@ -1133,18 +1037,21 @@ func TestSAMLConnectorCRUDEventsEmitted(t *testing.T) { err = s.a.UpsertSAMLConnector(ctx, saml) require.NoError(t, err) + require.IsType(t, &apievents.SAMLConnectorCreate{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.SAMLConnectorCreatedEvent) s.mockEmitter.Reset() // test saml update event err = s.a.UpsertSAMLConnector(ctx, saml) require.NoError(t, err) + require.IsType(t, &apievents.SAMLConnectorCreate{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.SAMLConnectorCreatedEvent) s.mockEmitter.Reset() // test saml delete event err = s.a.DeleteSAMLConnector(ctx, "test") require.NoError(t, err) + require.IsType(t, &apievents.SAMLConnectorDelete{}, s.mockEmitter.LastEvent()) require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.SAMLConnectorDeletedEvent) } diff --git a/lib/auth/auth_with_roles.go b/lib/auth/auth_with_roles.go index 7f87aadd454ca..621aa784638cb 100644 --- a/lib/auth/auth_with_roles.go +++ b/lib/auth/auth_with_roles.go @@ -2935,7 +2935,7 @@ func (a *ServerWithRoles) UpsertSAMLConnector(ctx context.Context, connector typ return trace.Wrap(err) } if !modules.GetModules().Features().SAML { - return trace.AccessDenied("SAML is only available in enterprise subscriptions") + return trace.Wrap(ErrSAMLRequiresEnterprise) } return a.authServer.UpsertSAMLConnector(ctx, connector) } diff --git a/lib/auth/clt.go b/lib/auth/clt.go index 03bbd21c61848..ad2b454e4f60a 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -959,13 +959,13 @@ func (c *Client) GenerateHostCert( // ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect func (c *Client) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), validateOIDCAuthCallbackReq{ + out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{ Query: q, }) if err != nil { return nil, trace.Wrap(err) } - var rawResponse *oidcAuthRawResponse + var rawResponse *OIDCAuthRawResponse if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { return nil, trace.Wrap(err) } @@ -996,14 +996,14 @@ func (c *Client) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O // ValidateSAMLResponse validates response returned by SAML identity provider func (c *Client) ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("saml", "requests", "validate"), validateSAMLResponseReq{ + out, err := c.PostJSON(ctx, c.Endpoint("saml", "requests", "validate"), ValidateSAMLResponseReq{ Response: re, ConnectorID: connectorID, }) if err != nil { return nil, trace.Wrap(err) } - var rawResponse *samlAuthRawResponse + var rawResponse *SAMLAuthRawResponse if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/github.go b/lib/auth/github.go index 92ff368191da4..4a4bc4ab94977 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -346,16 +346,16 @@ func GithubAuthRequestFromProto(req *types.GithubAuthRequest) GithubAuthRequest } type githubManager interface { - validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) - newSSODiagContext(authKind string) *ssoDiagContext + validateGithubAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) } // ValidateGithubAuthCallback validates Github auth callback redirect func (a *Server) ValidateGithubAuthCallback(ctx context.Context, q url.Values) (*GithubAuthResponse, error) { - return validateGithubAuthCallbackHelper(ctx, a, q, a.emitter) + diagCtx := NewSSODiagContext(types.KindGithub, a) + return validateGithubAuthCallbackHelper(ctx, a, diagCtx, q, a.emitter) } -func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, q url.Values, emitter apievents.Emitter) (*GithubAuthResponse, error) { +func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, diagCtx *SSODiagContext, q url.Values, emitter apievents.Emitter) (*GithubAuthResponse, error) { event := &apievents.UserLogin{ Metadata: apievents.Metadata{ Type: events.UserLoginEvent, @@ -363,14 +363,12 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, q ur Method: events.LoginMethodGithub, } - diagCtx := m.newSSODiagContext(types.KindGithub) - auth, err := m.validateGithubAuthCallback(ctx, diagCtx, q) - diagCtx.info.Error = trace.UserMessage(err) + diagCtx.Info.Error = trace.UserMessage(err) - diagCtx.writeToBackend(ctx) + diagCtx.WriteToBackend(ctx) - claims := diagCtx.info.GithubClaims + claims := diagCtx.Info.GithubClaims if claims != nil { attributes, err := apievents.EncodeMapStrings(claims.OrganizationToTeams) if err != nil { @@ -383,7 +381,7 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, q ur if err != nil { event.Code = events.UserSSOLoginFailureCode - if diagCtx.info.TestFlow { + if diagCtx.Info.TestFlow { event.Code = events.UserSSOTestFlowLoginFailureCode } event.Status.Success = false @@ -396,7 +394,7 @@ func validateGithubAuthCallbackHelper(ctx context.Context, m githubManager, q ur return nil, trace.Wrap(err) } event.Code = events.UserSSOLoginCode - if diagCtx.info.TestFlow { + if diagCtx.Info.TestFlow { event.Code = events.UserSSOTestFlowLoginCode } event.Status.Success = true @@ -490,17 +488,17 @@ func (a *Server) getGithubOAuth2Client(connector types.GithubConnector) (*oauth2 } // ValidateGithubAuthCallback validates Github auth callback redirect -func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { +func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) { logger := log.WithFields(logrus.Fields{trace.Component: "github"}) if errParam := q.Get("error"); errParam != "" { // try to find request so the error gets logged against it. state := q.Get("state") if state != "" { - diagCtx.requestID = state + diagCtx.RequestID = state req, err := a.Services.GetGithubAuthRequest(ctx, state) if err == nil { - diagCtx.info.TestFlow = req.SSOTestFlow + diagCtx.Info.TestFlow = req.SSOTestFlow } } @@ -521,20 +519,20 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia oauthErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "missing state query param", q) return nil, trace.WithUserMessage(oauthErr, "Invalid parameters received from Github.") } - diagCtx.requestID = stateToken + diagCtx.RequestID = stateToken req, err := a.Services.GetGithubAuthRequest(ctx, stateToken) if err != nil { return nil, trace.Wrap(err, "Failed to get OIDC Auth Request.") } - diagCtx.info.TestFlow = req.SSOTestFlow + diagCtx.Info.TestFlow = req.SSOTestFlow connector, client, err := a.getGithubConnectorAndClient(ctx, *req) if err != nil { return nil, trace.Wrap(err, "Failed to get Github connector and client.") } - diagCtx.info.GithubTeamsToLogins = connector.GetTeamsToLogins() - diagCtx.info.GithubTeamsToRoles = connector.GetTeamsToRoles() + diagCtx.Info.GithubTeamsToLogins = connector.GetTeamsToLogins() + diagCtx.Info.GithubTeamsToRoles = connector.GetTeamsToRoles() logger.Debugf("Connector %q teams to logins: %v, roles: %v", connector.GetName(), connector.GetTeamsToLogins(), connector.GetTeamsToRoles()) // exchange the authorization code received by the callback for an access token @@ -543,7 +541,7 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia return nil, trace.Wrap(err, "Requesting Github OAuth2 token failed.") } - diagCtx.info.GithubTokenInfo = &types.GithubTokenInfo{ + diagCtx.Info.GithubTokenInfo = &types.GithubTokenInfo{ TokenType: token.TokenType, Expires: int64(token.Expires), Scope: token.Scope, @@ -589,7 +587,7 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia if err != nil { return nil, trace.Wrap(err, "Failed to query Github API for user claims.") } - diagCtx.info.GithubClaims = claims + diagCtx.Info.GithubClaims = claims // Calculate (figure out name, roles, traits, session TTL) of user and // create the user in the backend. @@ -598,14 +596,14 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia return nil, trace.Wrap(err, "Failed to calculate user attributes.") } - diagCtx.info.CreateUserParams = &types.CreateUserParams{ - ConnectorName: params.connectorName, - Username: params.username, - KubeGroups: params.kubeGroups, - KubeUsers: params.kubeUsers, - Roles: params.roles, - Traits: params.traits, - SessionTTL: types.Duration(params.sessionTTL), + diagCtx.Info.CreateUserParams = &types.CreateUserParams{ + ConnectorName: params.ConnectorName, + Username: params.Username, + KubeGroups: params.KubeGroups, + KubeUsers: params.KubeUsers, + Roles: params.Roles, + Traits: params.Traits, + SessionTTL: types.Duration(params.SessionTTL), } user, err := a.createGithubUser(ctx, params, req.SSOTestFlow) @@ -617,25 +615,25 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia auth := GithubAuthResponse{ Req: GithubAuthRequestFromProto(req), Identity: types.ExternalIdentity{ - ConnectorID: params.connectorName, - Username: params.username, + ConnectorID: params.ConnectorName, + Username: params.Username, }, Username: user.GetName(), } // In test flow skip signing and creating web sessions. if req.SSOTestFlow { - diagCtx.info.Success = true + diagCtx.Info.Success = true return &auth, nil } // If the request is coming from a browser, create a web session. if req.CreateWebSession { - session, err := a.createWebSession(ctx, types.NewWebSessionRequest{ + session, err := a.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ User: user.GetName(), Roles: user.GetRoles(), Traits: user.GetTraits(), - SessionTTL: params.sessionTTL, + SessionTTL: params.SessionTTL, LoginTime: a.clock.Now().UTC(), }) if err != nil { @@ -647,7 +645,7 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, req.PublicKey, req.Compatibility, req.RouteToCluster, + sshCert, tlsCert, err := a.CreateSessionCert(user, params.SessionTTL, req.PublicKey, req.Compatibility, req.RouteToCluster, req.KubernetesCluster, keys.AttestationStatementFromProto(req.AttestationStatement)) if err != nil { return nil, trace.Wrap(err, "Failed to create session certificate.") @@ -675,89 +673,89 @@ func (a *Server) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDia return &auth, nil } -// createUserParams is a set of parameters used to create a user for an +// CreateUserParams is a set of parameters used to create a user for an // external identity provider. -type createUserParams struct { - // connectorName is the name of the connector for the identity provider. - connectorName string +type CreateUserParams struct { + // ConnectorName is the name of the connector for the identity provider. + ConnectorName string - // username is the Teleport user name . - username string + // Username is the Teleport user name . + Username string - // kubeGroups is the list of Kubernetes groups this user belongs to. - kubeGroups []string + // KubeGroups is the list of Kubernetes groups this user belongs to. + KubeGroups []string - // kubeUsers is the list of Kubernetes users this user belongs to. - kubeUsers []string + // KubeUsers is the list of Kubernetes users this user belongs to. + KubeUsers []string - // roles is the list of roles this user is assigned to. - roles []string + // Roles is the list of Roles this user is assigned to. + Roles []string - // traits is the list of traits for this user. - traits map[string][]string + // Traits is the list of Traits for this user. + Traits map[string][]string - // sessionTTL is how long this session will last. - sessionTTL time.Duration + // SessionTTL is how long this session will last. + SessionTTL time.Duration } -func (a *Server) calculateGithubUser(connector types.GithubConnector, claims *types.GithubClaims, request *types.GithubAuthRequest) (*createUserParams, error) { - p := createUserParams{ - connectorName: connector.GetName(), - username: claims.Username, +func (a *Server) calculateGithubUser(connector types.GithubConnector, claims *types.GithubClaims, request *types.GithubAuthRequest) (*CreateUserParams, error) { + p := CreateUserParams{ + ConnectorName: connector.GetName(), + Username: claims.Username, } // Calculate logins, kubegroups, roles, and traits. - p.roles, p.kubeGroups, p.kubeUsers = connector.MapClaims(*claims) - if len(p.roles) == 0 { + p.Roles, p.KubeGroups, p.KubeUsers = connector.MapClaims(*claims) + if len(p.Roles) == 0 { return nil, trace.Wrap(ErrGithubNoTeams) } - p.traits = map[string][]string{ - constants.TraitLogins: {p.username}, - constants.TraitKubeGroups: p.kubeGroups, - constants.TraitKubeUsers: p.kubeUsers, + p.Traits = map[string][]string{ + constants.TraitLogins: {p.Username}, + constants.TraitKubeGroups: p.KubeGroups, + constants.TraitKubeUsers: p.KubeUsers, teleport.TraitTeams: claims.Teams, } // Pick smaller for role: session TTL from role or requested TTL. - roles, err := services.FetchRoles(p.roles, a, p.traits) + roles, err := services.FetchRoles(p.Roles, a, p.Traits) if err != nil { return nil, trace.Wrap(err) } roleTTL := roles.AdjustSessionTTL(apidefaults.MaxCertDuration) - p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + p.SessionTTL = utils.MinTTL(roleTTL, request.CertTTL) return &p, nil } -func (a *Server) createGithubUser(ctx context.Context, p *createUserParams, dryRun bool) (types.User, error) { +func (a *Server) createGithubUser(ctx context.Context, p *CreateUserParams, dryRun bool) (types.User, error) { log.WithFields(logrus.Fields{trace.Component: "github"}).Debugf( "Generating dynamic GitHub identity %v/%v with roles: %v. Dry run: %v.", - p.connectorName, p.username, p.roles, dryRun) + p.ConnectorName, p.Username, p.Roles, dryRun) - expires := a.GetClock().Now().UTC().Add(p.sessionTTL) + expires := a.GetClock().Now().UTC().Add(p.SessionTTL) user := &types.UserV2{ Kind: types.KindUser, Version: types.V2, Metadata: types.Metadata{ - Name: p.username, + Name: p.Username, Namespace: apidefaults.Namespace, Expires: &expires, }, Spec: types.UserSpecV2{ - Roles: p.roles, - Traits: p.traits, + Roles: p.Roles, + Traits: p.Traits, GithubIdentities: []types.ExternalIdentity{{ - ConnectorID: p.connectorName, - Username: p.username, + ConnectorID: p.ConnectorName, + Username: p.Username, }}, CreatedBy: types.CreatedBy{ User: types.UserRef{Name: teleport.UserSystem}, Time: a.GetClock().Now().UTC(), Connector: &types.ConnectorRef{ Type: constants.Github, - ID: p.connectorName, - Identity: p.username, + ID: p.ConnectorName, + Identity: p.Username, }, }, }, @@ -767,7 +765,7 @@ func (a *Server) createGithubUser(ctx context.Context, p *createUserParams, dryR return user, nil } - existingUser, err := a.Services.GetUser(p.username, false) + existingUser, err := a.Services.GetUser(p.Username, false) if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) } diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index 3a9a9c2c3f126..f0194286de9fd 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -120,11 +120,11 @@ func TestCreateGithubUser(t *testing.T) { tt := setupGithubContext(ctx, t) // Dry-run creation of Github user. - user, err := tt.a.createGithubUser(context.Background(), &createUserParams{ - connectorName: "github", - username: "foo@example.com", - roles: []string{"admin"}, - sessionTTL: 1 * time.Minute, + user, err := tt.a.createGithubUser(context.Background(), &CreateUserParams{ + ConnectorName: "github", + Username: "foo@example.com", + Roles: []string{"admin"}, + SessionTTL: 1 * time.Minute, }, true) require.NoError(t, err) require.Equal(t, user.GetName(), "foo@example.com") @@ -134,11 +134,11 @@ func TestCreateGithubUser(t *testing.T) { require.Error(t, err) // Create GitHub user with 1 minute expiry. - _, err = tt.a.createGithubUser(context.Background(), &createUserParams{ - connectorName: "github", - username: "foo", - roles: []string{"admin"}, - sessionTTL: 1 * time.Minute, + _, err = tt.a.createGithubUser(context.Background(), &CreateUserParams{ + ConnectorName: "github", + Username: "foo", + Roles: []string{"admin"}, + SessionTTL: 1 * time.Minute, }, false) require.NoError(t, err) @@ -193,81 +193,69 @@ func TestValidateGithubAuthCallbackEventsEmitted(t *testing.T) { } ssoDiagInfoCalls := 0 - - m := &mockedGithubManager{} - m.createSSODiagnosticInfo = func(ctx context.Context, authKind string, authRequestID string, info types.SSODiagnosticInfo) error { + createSSODiagnosticInfoStub := func(ctx context.Context, authKind string, authRequestID string, entry types.SSODiagnosticInfo) error { ssoDiagInfoCalls++ return nil } - // TestFlow: false - m.testFlow = false + ssoDiagContextFixture := func(testFlow bool) *SSODiagContext { + diagCtx := NewSSODiagContext(types.KindGithub, SSODiagServiceFunc(createSSODiagnosticInfoStub)) + diagCtx.RequestID = uuid.New().String() + diagCtx.Info.TestFlow = testFlow + return diagCtx + } + m := &mockedGithubManager{} - // Test success event. - m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { - diagCtx.info.GithubClaims = claims + // Test success event, non-test-flow. + diagCtx := ssoDiagContextFixture(false /* testFlow */) + m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) { + diagCtx.Info.GithubClaims = claims return auth, nil } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, diagCtx, nil, tt.a.emitter) require.Equal(t, tt.mockEmitter.LastEvent().GetType(), events.UserLoginEvent) require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOLoginCode) require.Equal(t, ssoDiagInfoCalls, 0) tt.mockEmitter.Reset() - // Test failure event. - m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { - diagCtx.info.GithubClaims = claims + // Test failure event, non-test-flow. + diagCtx = ssoDiagContextFixture(false /* testFlow */) + m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) { + diagCtx.Info.GithubClaims = claims return auth, trace.BadParameter("") } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, diagCtx, nil, tt.a.emitter) require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOLoginFailureCode) require.Equal(t, ssoDiagInfoCalls, 0) - // TestFlow: true - m.testFlow = true - - // Test success event. - m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { - diagCtx.info.GithubClaims = claims + // Test success event, test-flow. + diagCtx = ssoDiagContextFixture(true /* testFlow */) + m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) { + diagCtx.Info.GithubClaims = claims return auth, nil } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, diagCtx, nil, tt.a.emitter) require.Equal(t, tt.mockEmitter.LastEvent().GetType(), events.UserLoginEvent) require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOTestFlowLoginCode) require.Equal(t, ssoDiagInfoCalls, 1) tt.mockEmitter.Reset() - // Test failure event. - m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { - diagCtx.info.GithubClaims = claims + // Test failure event, test-flow. + diagCtx = ssoDiagContextFixture(true /* testFlow */) + m.mockValidateGithubAuthCallback = func(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) { + diagCtx.Info.GithubClaims = claims return auth, trace.BadParameter("") } - _, _ = validateGithubAuthCallbackHelper(context.Background(), m, nil, tt.a.emitter) + _, _ = validateGithubAuthCallbackHelper(context.Background(), m, diagCtx, nil, tt.a.emitter) require.Equal(t, tt.mockEmitter.LastEvent().GetCode(), events.UserSSOTestFlowLoginFailureCode) require.Equal(t, ssoDiagInfoCalls, 2) } type mockedGithubManager struct { - mockValidateGithubAuthCallback func(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) - createSSODiagnosticInfo func(ctx context.Context, authKind string, authRequestID string, info types.SSODiagnosticInfo) error - - testFlow bool -} - -func (m *mockedGithubManager) newSSODiagContext(authKind string) *ssoDiagContext { - if m.createSSODiagnosticInfo == nil { - panic("mockedGithubManager.createSSODiagnosticInfo is nil, newSSODiagContext cannot succeed.") - } - - return &ssoDiagContext{ - authKind: authKind, - createSSODiagnosticInfo: m.createSSODiagnosticInfo, - requestID: uuid.New().String(), - info: types.SSODiagnosticInfo{TestFlow: m.testFlow}, - } + mockValidateGithubAuthCallback func(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) } -func (m *mockedGithubManager) validateGithubAuthCallback(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*GithubAuthResponse, error) { +func (m *mockedGithubManager) validateGithubAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*GithubAuthResponse, error) { if m.mockValidateGithubAuthCallback != nil { return m.mockValidateGithubAuthCallback(ctx, diagCtx, q) } diff --git a/lib/auth/helpers.go b/lib/auth/helpers.go index 1dd1418144f60..ffe7a7aa6b009 100644 --- a/lib/auth/helpers.go +++ b/lib/auth/helpers.go @@ -121,17 +121,35 @@ func NewTestServer(cfg TestServerConfig) (*TestServer, error) { if err != nil { return nil, trace.Wrap(err) } - var tlsServer *TestTLSServer - if cfg.TLS != nil { - tlsServer, err = NewTestTLSServer(*cfg.TLS) - if err != nil { - return nil, trace.Wrap(err) - } - } else { - tlsServer, err = authServer.NewTestTLSServer() - if err != nil { - return nil, trace.Wrap(err) - } + // Set the (test) auth server in cfg.TLS and set any defaults that + // are not set. + tlsCfg := cfg.TLS + if tlsCfg == nil { + tlsCfg = &TestTLSServerConfig{} + } + if tlsCfg.APIConfig == nil { + tlsCfg.APIConfig = &APIConfig{} + } + + tlsCfg.AuthServer = authServer + tlsCfg.APIConfig.AuthServer = authServer.AuthServer + + if tlsCfg.APIConfig.Authorizer == nil { + tlsCfg.APIConfig.Authorizer = authServer.Authorizer + } + if tlsCfg.APIConfig.AuditLog == nil { + tlsCfg.APIConfig.AuditLog = authServer.AuditLog + } + if tlsCfg.APIConfig.Emitter == nil { + tlsCfg.APIConfig.Emitter = authServer.AuthServer.emitter + } + if tlsCfg.AcceptedUsage == nil { + tlsCfg.AcceptedUsage = authServer.AcceptedUsage + } + + tlsServer, err := NewTestTLSServer(*tlsCfg) + if err != nil { + return nil, trace.Wrap(err) } return &TestServer{ AuthServer: authServer, diff --git a/lib/auth/methods.go b/lib/auth/methods.go index d9055fe58625f..b7e979e4d5916 100644 --- a/lib/auth/methods.go +++ b/lib/auth/methods.go @@ -553,7 +553,7 @@ func (s *Server) emitNoLocalAuthEvent(username string) { func (s *Server) createUserWebSession(ctx context.Context, user types.User) (types.WebSession, error) { // It's safe to extract the roles and traits directly from services.User as this method // is only used for local accounts. - return s.createWebSession(ctx, types.NewWebSessionRequest{ + return s.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ User: user.GetName(), Roles: user.GetRoles(), Traits: user.GetTraits(), diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index ff7e621b8b39a..00fb100cecbc2 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -23,6 +23,7 @@ import ( "io" "net/http" "net/url" + "sync" "time" "github.com/coreos/go-oidc/jose" @@ -44,12 +45,148 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +type OIDCService interface { + CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) + ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) +} + +var errOIDCNotImplemented = trace.AccessDenied("OIDC is only available in enterprise subscriptions") + +func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) { + if a.oidcAuthService == nil { + return nil, errOIDCNotImplemented + } + + rq, err := a.oidcAuthService.CreateOIDCAuthRequest(ctx, req) + return rq, trace.Wrap(err) +} + +func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { + if a.oidcAuthService == nil { + return nil, errOIDCNotImplemented + } + + resp, err := a.oidcAuthService.ValidateOIDCAuthCallback(ctx, q) + return resp, trace.Wrap(err) +} + +// OIDCAuthResponse is returned when auth server validated callback parameters +// returned from OIDC provider +type OIDCAuthResponse struct { + // Username is authenticated teleport username + Username string `json:"username"` + // Identity contains validated OIDC identity + Identity types.ExternalIdentity `json:"identity"` + // Web session will be generated by auth server if requested in OIDCAuthRequest + Session types.WebSession `json:"session,omitempty"` + // Cert will be generated by certificate authority + Cert []byte `json:"cert,omitempty"` + // TLSCert is PEM encoded TLS certificate + TLSCert []byte `json:"tls_cert,omitempty"` + // Req is original oidc auth request + Req OIDCAuthRequest `json:"req"` + // HostSigners is a list of signing host public keys + // trusted by proxy, used in console login + HostSigners []types.CertAuthority `json:"host_signers"` +} + +// OIDCAuthRequest is an OIDC auth request that supports standard json marshaling. +type OIDCAuthRequest struct { + // ConnectorID is ID of OIDC connector this request uses + ConnectorID string `json:"connector_id"` + // CSRFToken is associated with user web session token + CSRFToken string `json:"csrf_token"` + // PublicKey is an optional public key, users want these + // keys to be signed by auth servers user CA in case + // of successful auth + PublicKey []byte `json:"public_key"` + // CreateWebSession indicates if user wants to generate a web + // session after successful authentication + CreateWebSession bool `json:"create_web_session"` + // ClientRedirectURL is a URL client wants to be redirected + // after successful authentication + ClientRedirectURL string `json:"client_redirect_url"` +} + +// ValidateOIDCAuthCallbackReq is the request made by the proxy to validate +// and activate a login via OIDC. +type ValidateOIDCAuthCallbackReq struct { + Query url.Values `json:"query"` +} + +// OIDCAuthRawResponse is returned when auth server validated callback parameters +// returned from OIDC provider +type OIDCAuthRawResponse struct { + // Username is authenticated teleport username + Username string `json:"username"` + // Identity contains validated OIDC identity + Identity types.ExternalIdentity `json:"identity"` + // Web session will be generated by auth server if requested in OIDCAuthRequest + Session json.RawMessage `json:"session,omitempty"` + // Cert will be generated by certificate authority + Cert []byte `json:"cert,omitempty"` + // TLSCert is PEM encoded TLS certificate + TLSCert []byte `json:"tls_cert,omitempty"` + // Req is original oidc auth request + Req OIDCAuthRequest `json:"req"` + // HostSigners is a list of signing host public keys + // trusted by proxy, used in console login + HostSigners []json.RawMessage `json:"host_signers"` +} + +type OIDCAuthService struct { + auth *Server + emitter apievents.Emitter + clients map[string]*oidcClient + lock sync.Mutex + getClaimsFun func(ctx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) +} + +type OIDCAuthServiceConfig struct { + Auth *Server + Emitter apievents.Emitter +} + +func (cfg *OIDCAuthServiceConfig) CheckAndSetDefaults() error { + if cfg.Auth == nil { + return trace.BadParameter("auth.Server not provided") + } + if cfg.Emitter == nil { + cfg.Emitter = events.NewDiscardEmitter() + } + return nil +} + +func NewOIDCAuthService(cfg *OIDCAuthServiceConfig) (*OIDCAuthService, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, err + } + + return &OIDCAuthService{ + auth: cfg.Auth, + emitter: cfg.Emitter, + clients: make(map[string]*oidcClient), + getClaimsFun: getClaims, + }, nil +} + +// oidcClient is internal structure that stores OIDC client and its config +type oidcClient struct { + client *oidc.Client + connector types.OIDCConnector + // syncCtx controls the provider sync goroutine. + syncCtx context.Context + syncCancel context.CancelFunc + // firstSync will be closed once the first provider sync succeeds + firstSync chan struct{} +} + // ErrOIDCNoRoles results from not mapping any roles from OIDC claims. var ErrOIDCNoRoles = trace.AccessDenied("No roles mapped from claims. The mappings may contain typos.") // getOIDCConnectorAndClient returns the associated oidc connector // and client for the given oidc auth request. -func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OIDCAuthRequest) (types.OIDCConnector, *oidc.Client, error) { +func (oas *OIDCAuthService) getOIDCConnectorAndClient(ctx context.Context, request types.OIDCAuthRequest) (types.OIDCConnector, *oidc.Client, error) { // stateless test flow if request.SSOTestFlow { if request.ConnectorSpec == nil { @@ -76,7 +213,7 @@ func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OI // close this request-scoped oidc client after 10 minutes go func() { - ticker := a.GetClock().NewTicker(defaults.OIDCAuthRequestTTL) + ticker := oas.auth.GetClock().NewTicker(defaults.OIDCAuthRequestTTL) defer ticker.Stop() select { case <-ticker.Chan(): @@ -89,12 +226,12 @@ func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OI } // regular execution flow - connector, err := a.GetOIDCConnector(ctx, request.ConnectorID, true) + connector, err := oas.auth.GetOIDCConnector(ctx, request.ConnectorID, true) if err != nil { return nil, nil, trace.Wrap(err) } - client, err := a.getCachedOIDCClient(ctx, connector, request.ProxyAddress) + client, err := oas.getCachedOIDCClient(ctx, connector, request.ProxyAddress) if err != nil { return nil, nil, trace.Wrap(err) } @@ -110,22 +247,22 @@ func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OI // getCachedOIDCClient gets a cached oidc client for // the given OIDC connector and redirectURL preference. -func (a *Server) getCachedOIDCClient(ctx context.Context, conn types.OIDCConnector, proxyAddr string) (*oidcClient, error) { - a.lock.Lock() - defer a.lock.Unlock() +func (oas *OIDCAuthService) getCachedOIDCClient(ctx context.Context, conn types.OIDCConnector, proxyAddr string) (*oidcClient, error) { + oas.lock.Lock() + defer oas.lock.Unlock() // Each connector and proxy combination has a distinct client, // so we use a composite key to capture all combinations. clientMapKey := conn.GetName() + "_" + proxyAddr - cachedClient, ok := a.oidcClients[clientMapKey] + cachedClient, ok := oas.clients[clientMapKey] if ok { if !cachedClient.needsRefresh(conn) && cachedClient.syncCtx.Err() == nil { return cachedClient, nil } // Cached client needs to be refreshed or is no longer syncing. cachedClient.syncCancel() - delete(a.oidcClients, clientMapKey) + delete(oas.clients, clientMapKey) } // Create a new oidc client and add it to the cache. @@ -134,7 +271,7 @@ func (a *Server) getCachedOIDCClient(ctx context.Context, conn types.OIDCConnect return nil, trace.Wrap(err) } - a.oidcClients[clientMapKey] = client + oas.clients[clientMapKey] = client return client, nil } @@ -259,12 +396,12 @@ func (a *Server) DeleteOIDCConnector(ctx context.Context, connectorName string) return nil } -func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) { +func (oas *OIDCAuthService) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) { // ensure prompt removal of OIDC client in test flows. does nothing in regular flows. ctx, cancel := context.WithCancel(ctx) defer cancel() - connector, client, err := a.getOIDCConnectorAndClient(ctx, req) + connector, client, err := oas.getOIDCConnectorAndClient(ctx, req) if err != nil { return nil, trace.Wrap(err) } @@ -299,7 +436,7 @@ func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRe log.Debugf("OIDC redirect URL: %v.", req.RedirectURL) - err = a.Services.CreateOIDCAuthRequest(ctx, req, defaults.OIDCAuthRequestTTL) + err = oas.auth.Services.CreateOIDCAuthRequest(ctx, req, defaults.OIDCAuthRequestTTL) if err != nil { return nil, trace.Wrap(err) } @@ -309,7 +446,7 @@ func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRe // ValidateOIDCAuthCallback is called by the proxy to check OIDC query parameters // returned by OIDC Provider, if everything checks out, auth server // will respond with OIDCAuthResponse, otherwise it will return error -func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { +func (oas *OIDCAuthService) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { event := &apievents.UserLogin{ Metadata: apievents.Metadata{ Type: events.UserLoginEvent, @@ -317,14 +454,14 @@ func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O Method: events.LoginMethodOIDC, } - diagCtx := a.newSSODiagContext(types.KindOIDC) + diagCtx := NewSSODiagContext(types.KindOIDC, oas.auth) - auth, err := a.validateOIDCAuthCallback(ctx, diagCtx, q) - diagCtx.info.Error = trace.UserMessage(err) + auth, err := oas.validateOIDCAuthCallback(ctx, diagCtx, q) + diagCtx.Info.Error = trace.UserMessage(err) - diagCtx.writeToBackend(ctx) + diagCtx.WriteToBackend(ctx) - claims := diagCtx.info.OIDCClaims + claims := diagCtx.Info.OIDCClaims if claims != nil { attributes, err := apievents.EncodeMap(claims) if err != nil { @@ -337,14 +474,14 @@ func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O if err != nil { event.Code = events.UserSSOLoginFailureCode - if diagCtx.info.TestFlow { + if diagCtx.Info.TestFlow { event.Code = events.UserSSOTestFlowLoginFailureCode } event.Status.Success = false event.Status.Error = trace.Unwrap(err).Error() event.Status.UserMessage = err.Error() - if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil { + if err := oas.emitter.EmitAuditEvent(ctx, event); err != nil { log.WithError(err).Warn("Failed to emit OIDC login failed event.") } @@ -352,13 +489,13 @@ func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O } event.Code = events.UserSSOLoginCode - if diagCtx.info.TestFlow { + if diagCtx.Info.TestFlow { event.Code = events.UserSSOTestFlowLoginCode } event.User = auth.Username event.Status.Success = true - if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil { + if err := oas.emitter.EmitAuditEvent(ctx, event); err != nil { log.WithError(err).Warn("Failed to emit OIDC login event.") } @@ -398,15 +535,15 @@ func checkEmailVerifiedClaim(claims jose.Claims) error { return nil } -func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *ssoDiagContext, q url.Values) (*OIDCAuthResponse, error) { +func (oas *OIDCAuthService) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*OIDCAuthResponse, error) { if errParam := q.Get("error"); errParam != "" { // try to find request so the error gets logged against it. state := q.Get("state") if state != "" { - diagCtx.requestID = state - req, err := a.GetOIDCAuthRequest(ctx, state) + diagCtx.RequestID = state + req, err := oas.auth.GetOIDCAuthRequest(ctx, state) if err == nil { - diagCtx.info.TestFlow = req.SSOTestFlow + diagCtx.Info.TestFlow = req.SSOTestFlow } } @@ -427,25 +564,25 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *ssoDiagC oidcErr := trace.OAuth2(oauth2.ErrorInvalidRequest, "missing state query param", q) return nil, trace.WithUserMessage(oidcErr, "Invalid parameters received from OIDC provider.") } - diagCtx.requestID = stateToken + diagCtx.RequestID = stateToken - req, err := a.GetOIDCAuthRequest(ctx, stateToken) + req, err := oas.auth.GetOIDCAuthRequest(ctx, stateToken) if err != nil { return nil, trace.Wrap(err, "Failed to get OIDC Auth Request.") } - diagCtx.info.TestFlow = req.SSOTestFlow + diagCtx.Info.TestFlow = req.SSOTestFlow // ensure prompt removal of OIDC client in test flows. does nothing in regular flows. ctxC, cancel := context.WithCancel(ctx) defer cancel() - connector, client, err := a.getOIDCConnectorAndClient(ctxC, *req) + connector, client, err := oas.getOIDCConnectorAndClient(ctxC, *req) if err != nil { return nil, trace.Wrap(err, "Failed to get OIDC connector and client.") } // extract claims from both the id token and the userinfo endpoint and merge them - claims, err := a.getClaims(client, connector, code) + claims, err := oas.getClaims(ctx, client, connector, code) if err != nil { // different error message for Google Workspace as likely cause is different. if isGoogleWorkspaceConnector(connector) { @@ -454,7 +591,7 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *ssoDiagC return nil, trace.Wrap(err, "Failed to extract OIDC claims. This may indicate need to set 'provider' flag in connector definition. See: https://goteleport.com/docs/enterprise/sso/#provider-specific-workarounds") } - diagCtx.info.OIDCClaims = types.OIDCClaims(claims) + diagCtx.Info.OIDCClaims = types.OIDCClaims(claims) log.Debugf("OIDC claims: %v.", claims) if !connector.GetAllowUnverifiedEmail() { @@ -466,7 +603,7 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *ssoDiagC // if we are sending acr values, make sure we also validate them acrValue := connector.GetACR() if acrValue != "" { - err := a.validateACRValues(acrValue, connector.GetProvider(), claims) + err := validateACRValues(acrValue, connector.GetProvider(), claims) if err != nil { return nil, trace.Wrap(err, "OIDC ACR validation failure.") } @@ -478,7 +615,7 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *ssoDiagC return nil, trace.OAuth2( oauth2.ErrorUnsupportedResponseType, "unable to convert claims to identity", q) } - diagCtx.info.OIDCIdentity = &types.OIDCIdentity{ + diagCtx.Info.OIDCIdentity = &types.OIDCIdentity{ ID: ident.ID, Name: ident.Name, Email: ident.Email, @@ -491,130 +628,92 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *ssoDiagC return nil, trace.WithUserMessage(oidcErr, "Claims-to-roles mapping is empty, SSO user will never have any roles.") } log.Debugf("Applying %v OIDC claims to roles mappings.", len(connector.GetClaimsToRoles())) - diagCtx.info.OIDCClaimsToRoles = connector.GetClaimsToRoles() + diagCtx.Info.OIDCClaimsToRoles = connector.GetClaimsToRoles() // Calculate (figure out name, roles, traits, session TTL) of user and // create the user in the backend. - params, err := a.calculateOIDCUser(diagCtx, connector, claims, ident, req) + params, err := oas.calculateOIDCUser(diagCtx, connector, claims, ident, req) if err != nil { return nil, trace.Wrap(err, "Failed to calculate user attributes.") } - diagCtx.info.CreateUserParams = &types.CreateUserParams{ - ConnectorName: params.connectorName, - Username: params.username, - KubeGroups: params.kubeGroups, - KubeUsers: params.kubeUsers, - Roles: params.roles, - Traits: params.traits, - SessionTTL: types.Duration(params.sessionTTL), + diagCtx.Info.CreateUserParams = &types.CreateUserParams{ + ConnectorName: params.ConnectorName, + Username: params.Username, + KubeGroups: params.KubeGroups, + KubeUsers: params.KubeUsers, + Roles: params.Roles, + Traits: params.Traits, + SessionTTL: types.Duration(params.SessionTTL), } - user, err := a.createOIDCUser(params, req.SSOTestFlow) + user, err := oas.createOIDCUser(params, req.SSOTestFlow) if err != nil { return nil, trace.Wrap(err, "Failed to create user from provided parameters.") } // Auth was successful, return session, certificate, etc. to caller. - auth := &OIDCAuthResponse{ + resp := &OIDCAuthResponse{ Req: OIDCAuthRequestFromProto(req), Identity: types.ExternalIdentity{ - ConnectorID: params.connectorName, - Username: params.username, + ConnectorID: params.ConnectorName, + Username: params.Username, }, Username: user.GetName(), } // In test flow skip signing and creating web sessions. if req.SSOTestFlow { - diagCtx.info.Success = true - return auth, nil + diagCtx.Info.Success = true + return resp, nil } if !req.CheckUser { - return auth, nil + return resp, nil } // If the request is coming from a browser, create a web session. if req.CreateWebSession { - session, err := a.createWebSession(ctx, types.NewWebSessionRequest{ + session, err := oas.auth.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ User: user.GetName(), Roles: user.GetRoles(), Traits: user.GetTraits(), - SessionTTL: params.sessionTTL, - LoginTime: a.clock.Now().UTC(), + SessionTTL: params.SessionTTL, + LoginTime: oas.auth.GetClock().Now().UTC(), }) if err != nil { return nil, trace.Wrap(err, "Failed to create web session.") } - auth.Session = session + resp.Session = session } // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, req.PublicKey, req.Compatibility, req.RouteToCluster, + sshCert, tlsCert, err := oas.auth.CreateSessionCert(user, params.SessionTTL, req.PublicKey, req.Compatibility, req.RouteToCluster, req.KubernetesCluster, keys.AttestationStatementFromProto(req.AttestationStatement)) if err != nil { return nil, trace.Wrap(err, "Failed to create session certificate.") } - clusterName, err := a.GetClusterName() + clusterName, err := oas.auth.GetClusterName() if err != nil { return nil, trace.Wrap(err, "Failed to obtain cluster name.") } - auth.Cert = sshCert - auth.TLSCert = tlsCert + resp.Cert = sshCert + resp.TLSCert = tlsCert // Return the host CA for this cluster only. - authority, err := a.GetCertAuthority(ctx, types.CertAuthID{ + authority, err := oas.auth.GetCertAuthority(ctx, types.CertAuthID{ Type: types.HostCA, DomainName: clusterName.GetClusterName(), }, false) if err != nil { return nil, trace.Wrap(err, "Failed to obtain cluster's host CA.") } - auth.HostSigners = append(auth.HostSigners, authority) + resp.HostSigners = append(resp.HostSigners, authority) } - return auth, nil -} - -// OIDCAuthResponse is returned when auth server validated callback parameters -// returned from OIDC provider -type OIDCAuthResponse struct { - // Username is authenticated teleport username - Username string `json:"username"` - // Identity contains validated OIDC identity - Identity types.ExternalIdentity `json:"identity"` - // Web session will be generated by auth server if requested in OIDCAuthRequest - Session types.WebSession `json:"session,omitempty"` - // Cert will be generated by certificate authority - Cert []byte `json:"cert,omitempty"` - // TLSCert is PEM encoded TLS certificate - TLSCert []byte `json:"tls_cert,omitempty"` - // Req is original oidc auth request - Req OIDCAuthRequest `json:"req"` - // HostSigners is a list of signing host public keys - // trusted by proxy, used in console login - HostSigners []types.CertAuthority `json:"host_signers"` -} - -// OIDCAuthRequest is an OIDC auth request that supports standard json marshaling. -type OIDCAuthRequest struct { - // ConnectorID is ID of OIDC connector this request uses - ConnectorID string `json:"connector_id"` - // CSRFToken is associated with user web session token - CSRFToken string `json:"csrf_token"` - // PublicKey is an optional public key, users want these - // keys to be signed by auth servers user CA in case - // of successful auth - PublicKey []byte `json:"public_key"` - // CreateWebSession indicates if user wants to generate a web - // session after successful authentication - CreateWebSession bool `json:"create_web_session"` - // ClientRedirectURL is a URL client wants to be redirected - // after successful authentication - ClientRedirectURL string `json:"client_redirect_url"` + return resp, nil } // OIDCAuthRequestFromProto converts the types.OIDCAuthRequest to OIDCAuthRequest. @@ -628,7 +727,7 @@ func OIDCAuthRequestFromProto(req *types.OIDCAuthRequest) OIDCAuthRequest { } } -func (a *Server) calculateOIDCUser(diagCtx *ssoDiagContext, connector types.OIDCConnector, claims jose.Claims, ident *oidc.Identity, request *types.OIDCAuthRequest) (*createUserParams, error) { +func (oas *OIDCAuthService) calculateOIDCUser(diagCtx *SSODiagContext, connector types.OIDCConnector, claims jose.Claims, ident *oidc.Identity, request *types.OIDCAuthRequest) (*CreateUserParams, error) { var err error username, err := usernameFromClaims(connector, claims, ident) @@ -636,28 +735,28 @@ func (a *Server) calculateOIDCUser(diagCtx *ssoDiagContext, connector types.OIDC return nil, err } - p := createUserParams{ - connectorName: connector.GetName(), - username: username, + p := CreateUserParams{ + ConnectorName: connector.GetName(), + Username: username, } - p.traits = services.OIDCClaimsToTraits(claims) + p.Traits = services.OIDCClaimsToTraits(claims) - diagCtx.info.OIDCTraitsFromClaims = p.traits - diagCtx.info.OIDCConnectorTraitMapping = connector.GetTraitMappings() + diagCtx.Info.OIDCTraitsFromClaims = p.Traits + diagCtx.Info.OIDCConnectorTraitMapping = connector.GetTraitMappings() var warnings []string - warnings, p.roles = services.TraitsToRoles(connector.GetTraitMappings(), p.traits) - if len(p.roles) == 0 { + warnings, p.Roles = services.TraitsToRoles(connector.GetTraitMappings(), p.Traits) + if len(p.Roles) == 0 { if len(warnings) != 0 { log.WithField("connector", connector).Warnf("No roles mapped from claims. Warnings: %q", warnings) - diagCtx.info.OIDCClaimsToRolesWarnings = &types.SSOWarnings{ + diagCtx.Info.OIDCClaimsToRolesWarnings = &types.SSOWarnings{ Message: "No roles mapped for the user", Warnings: warnings, } } else { log.WithField("connector", connector).Warnf("No roles mapped from claims.") - diagCtx.info.OIDCClaimsToRolesWarnings = &types.SSOWarnings{ + diagCtx.Info.OIDCClaimsToRolesWarnings = &types.SSOWarnings{ Message: "No roles mapped for the user. The mappings may contain typos.", } } @@ -665,44 +764,44 @@ func (a *Server) calculateOIDCUser(diagCtx *ssoDiagContext, connector types.OIDC } // Pick smaller for role: session TTL from role or requested TTL. - roles, err := services.FetchRoles(p.roles, a, p.traits) + roles, err := services.FetchRoles(p.Roles, oas.auth, p.Traits) if err != nil { return nil, trace.Wrap(err) } roleTTL := roles.AdjustSessionTTL(apidefaults.MaxCertDuration) - p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + p.SessionTTL = utils.MinTTL(roleTTL, request.CertTTL) return &p, nil } -func (a *Server) createOIDCUser(p *createUserParams, dryRun bool) (types.User, error) { - expires := a.GetClock().Now().UTC().Add(p.sessionTTL) +func (oas *OIDCAuthService) createOIDCUser(p *CreateUserParams, dryRun bool) (types.User, error) { + expires := oas.auth.GetClock().Now().UTC().Add(p.SessionTTL) - log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v. Dry run: %v.", p.connectorName, p.username, p.roles, dryRun) + log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v. Dry run: %v.", p.ConnectorName, p.Username, p.Roles, dryRun) user := &types.UserV2{ Kind: types.KindUser, Version: types.V2, Metadata: types.Metadata{ - Name: p.username, + Name: p.Username, Namespace: apidefaults.Namespace, Expires: &expires, }, Spec: types.UserSpecV2{ - Roles: p.roles, - Traits: p.traits, + Roles: p.Roles, + Traits: p.Traits, OIDCIdentities: []types.ExternalIdentity{ { - ConnectorID: p.connectorName, - Username: p.username, + ConnectorID: p.ConnectorName, + Username: p.Username, }, }, CreatedBy: types.CreatedBy{ User: types.UserRef{Name: teleport.UserSystem}, - Time: a.clock.Now().UTC(), + Time: oas.auth.GetClock().Now().UTC(), Connector: &types.ConnectorRef{ Type: constants.OIDC, - ID: p.connectorName, - Identity: p.username, + ID: p.ConnectorName, + Identity: p.Username, }, }, }, @@ -713,7 +812,7 @@ func (a *Server) createOIDCUser(p *createUserParams, dryRun bool) (types.User, e } // Get the user to check if it already exists or not. - existingUser, err := a.Services.GetUser(p.username, false) + existingUser, err := oas.auth.Services.GetUser(p.Username, false) if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) } @@ -733,11 +832,11 @@ func (a *Server) createOIDCUser(p *createUserParams, dryRun bool) (types.User, e log.Debugf("Overwriting existing user %q created with %v connector %v.", existingUser.GetName(), connectorRef.Type, connectorRef.ID) - if err := a.UpdateUser(ctx, user); err != nil { + if err := oas.auth.UpdateUser(ctx, user); err != nil { return nil, trace.Wrap(err) } } else { - if err := a.CreateUser(ctx, user); err != nil { + if err := oas.auth.CreateUser(ctx, user); err != nil { return nil, trace.Wrap(err) } } @@ -875,12 +974,12 @@ func mergeClaims(a jose.Claims, b jose.Claims) (jose.Claims, error) { } // getClaims gets claims from ID token and UserInfo and returns UserInfo claims merged into ID token claims. -func (a *Server) getClaims(oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { - return a.getClaimsFun(a.closeCtx, oidcClient, connector, code) +func (oas *OIDCAuthService) getClaims(ctx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { + return oas.getClaimsFun(ctx, oidcClient, connector, code) } -// getClaims implements Server.getClaims, but allows that code path to be overridden for testing. -func getClaims(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { +// getClaims implements OIDCAuthService.getClaims, but allows that code path to be overridden for testing. +func getClaims(ctx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { oac, err := getOAuthClient(oidcClient, connector) if err != nil { return nil, trace.Wrap(err) @@ -945,7 +1044,7 @@ func getClaims(closeCtx context.Context, oidcClient *oidc.Client, connector type } if isGoogleWorkspaceConnector(connector) { - claims, err = addGoogleWorkspaceClaims(closeCtx, connector, claims) + claims, err = addGoogleWorkspaceClaims(ctx, connector, claims) if err != nil { return nil, trace.Wrap(err) } @@ -975,7 +1074,7 @@ func getOAuthClient(oidcClient *oidc.Client, connector types.OIDCConnector) (*oa // validateACRValues validates that we get an appropriate response for acr values. By default // we expect the same value we send, but this function also handles Identity Provider specific // forms of validation. -func (a *Server) validateACRValues(acrValue string, identityProvider string, claims jose.Claims) error { +func validateACRValues(acrValue string, identityProvider string, claims jose.Claims) error { switch identityProvider { case teleport.NetIQ: log.Debugf("Validating OIDC ACR values with '%v' rules.", identityProvider) @@ -1030,3 +1129,16 @@ func (a *Server) validateACRValues(acrValue string, identityProvider string, cla return nil } + +// isHTTPS checks if the scheme for a URL is https or not. +func isHTTPS(u string) error { + earl, err := url.Parse(u) + if err != nil { + return trace.Wrap(err) + } + if earl.Scheme != "https" { + return trace.BadParameter("expected scheme https, got %q", earl.Scheme) + } + + return nil +} diff --git a/lib/auth/oidc_test.go b/lib/auth/oidc_test.go index 18e31c8c5bafb..78a96be4e4e70 100644 --- a/lib/auth/oidc_test.go +++ b/lib/auth/oidc_test.go @@ -51,9 +51,10 @@ import ( ) type OIDCSuite struct { - a *Server - b backend.Backend - c clockwork.FakeClock + a *Server + b backend.Backend + c clockwork.FakeClock + oas *OIDCAuthService } func setUpSuite(t *testing.T) *OIDCSuite { @@ -87,6 +88,11 @@ func setUpSuite(t *testing.T) *OIDCSuite { } s.a, err = NewServer(authConfig) require.NoError(t, err) + + var ok bool + s.oas, ok = s.a.oidcAuthService.(*OIDCAuthService) + require.True(t, ok, "Server.oidcAuthService is not type *OIDCAuthService") + return &s } @@ -112,11 +118,11 @@ func TestCreateOIDCUser(t *testing.T) { s := setUpSuite(t) // Dry-run creation of OIDC user. - user, err := s.a.createOIDCUser(&createUserParams{ - connectorName: "oidcService", - username: "foo@example.com", - roles: []string{"admin"}, - sessionTTL: 1 * time.Minute, + user, err := s.oas.createOIDCUser(&CreateUserParams{ + ConnectorName: "oidcService", + Username: "foo@example.com", + Roles: []string{"admin"}, + SessionTTL: 1 * time.Minute, }, true) require.NoError(t, err) require.Equal(t, "foo@example.com", user.GetName()) @@ -126,11 +132,11 @@ func TestCreateOIDCUser(t *testing.T) { require.Error(t, err) // Create OIDC user with 1 minute expiry. - _, err = s.a.createOIDCUser(&createUserParams{ - connectorName: "oidcService", - username: "foo@example.com", - roles: []string{"admin"}, - sessionTTL: 1 * time.Minute, + _, err = s.oas.createOIDCUser(&CreateUserParams{ + ConnectorName: "oidcService", + Username: "foo@example.com", + Roles: []string{"admin"}, + SessionTTL: 1 * time.Minute, }, false) require.NoError(t, err) @@ -153,6 +159,7 @@ func TestUserInfoBlockHTTP(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) @@ -166,7 +173,7 @@ func TestUserInfoBlockHTTP(t *testing.T) { }) require.NoError(t, err) - oidcClient, err := s.a.getCachedOIDCClient(ctx, connector, "") + oidcClient, err := s.oas.getCachedOIDCClient(ctx, connector, "") require.NoError(t, err) // Verify HTTP endpoints return trace.NotFound. @@ -232,6 +239,7 @@ func TestSSODiagnostic(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) @@ -274,7 +282,7 @@ func TestSSODiagnostic(t *testing.T) { } // override getClaimsFun. - s.a.getClaimsFun = func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { + s.oas.getClaimsFun = func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { cc := map[string]interface{}{ "email_verified": true, "groups": []string{"everyone", "idp-admin", "idp-dev"}, @@ -285,7 +293,7 @@ func TestSSODiagnostic(t *testing.T) { return cc, nil } - resp, err := s.a.ValidateOIDCAuthCallback(ctx, values) + resp, err := s.oas.ValidateOIDCAuthCallback(ctx, values) if tc.wantValidateErr != nil { require.ErrorIs(t, err, tc.wantValidateErr) return @@ -302,9 +310,9 @@ func TestSSODiagnostic(t *testing.T) { Req: OIDCAuthRequestFromProto(request), }, resp) - diagCtx := ssoDiagContext{} + diagCtx := SSODiagContext{} - resp, err = s.a.validateOIDCAuthCallback(ctx, &diagCtx, values) + resp, err = s.oas.validateOIDCAuthCallback(ctx, &diagCtx, values) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, &OIDCAuthResponse{ @@ -351,7 +359,7 @@ func TestSSODiagnostic(t *testing.T) { ID: "00001234abcd", Name: "", Email: "superuser@example.com", - ExpiresAt: diagCtx.info.OIDCIdentity.ExpiresAt, + ExpiresAt: diagCtx.Info.OIDCIdentity.ExpiresAt, }, OIDCTraitsFromClaims: map[string][]string{ "email": {"superuser@example.com"}, @@ -365,7 +373,7 @@ func TestSSODiagnostic(t *testing.T) { Roles: []string{"access"}, }, }, - }, diagCtx.info) + }, diagCtx.Info) }) } } @@ -377,6 +385,7 @@ func TestPingProvider(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) @@ -410,7 +419,7 @@ func TestPingProvider(t *testing.T) { }, } { t.Run(fmt.Sprintf("Test SSOFlow: %v", req.SSOTestFlow), func(t *testing.T) { - oidcConnector, oidcClient, err := s.a.getOIDCConnectorAndClient(ctx, req) + oidcConnector, oidcClient, err := s.oas.getOIDCConnectorAndClient(ctx, req) require.NoError(t, err) oac, err := getOAuthClient(oidcClient, oidcConnector) @@ -487,6 +496,7 @@ func TestOIDCClientCache(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) connectorSpec := types.OIDCConnectorSpecV3{ @@ -501,17 +511,17 @@ func TestOIDCClientCache(t *testing.T) { require.NoError(t, err) // Create and cache a new oidc client - client, err := s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + client, err := s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) // The next call should return the same client (compare memory address) - cachedClient, err := s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + cachedClient, err := s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) require.True(t, client == cachedClient) // Canceling provider sync on a cached client should cause it to be replaced client.syncCancel() - cachedClient, err = s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + cachedClient, err = s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) require.False(t, client == cachedClient) @@ -560,12 +570,12 @@ func TestOIDCClientCache(t *testing.T) { require.NoError(t, err) tc.mutateConnector(newConnector) - client, err = s.a.getCachedOIDCClient(ctx, newConnector, "proxy.example.com") + client, err = s.oas.getCachedOIDCClient(ctx, newConnector, "proxy.example.com") require.NoError(t, err) require.True(t, (client == originalClient) == tc.expectNoRefresh) // reset cached client to the original client for remaining tests - originalClient, err = s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + originalClient, err = s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) }) } @@ -795,7 +805,7 @@ func TestUsernameClaim(t *testing.T) { s := setUpSuite(t) idp := newFakeIDP(t, false) - diagCtx := ssoDiagContext{} + diagCtx := SSODiagContext{} // Create role that will be mapped to the user. role, err := types.NewRole("access", types.RoleSpecV5{ @@ -880,13 +890,112 @@ func TestUsernameClaim(t *testing.T) { require.NoError(t, err) // Generate the userCreateParams for the OIDC user. - createUserParams, err := s.a.calculateOIDCUser(&diagCtx, connector, claims, ident, request) + createUserParams, err := s.oas.calculateOIDCUser(&diagCtx, connector, claims, ident, request) if tc.expectedError != "" { require.ErrorContains(t, err, tc.expectedError) } else { require.NoError(t, err) - require.Equal(t, tc.expectedUsername, createUserParams.username) + require.Equal(t, tc.expectedUsername, createUserParams.Username) } }) } } + +func TestValidateACRValues(t *testing.T) { + tests := []struct { + comment string + inIDToken string + inACRValue string + inACRProvider string + outIsValid require.ErrorAssertionFunc + }{ + { + "0 - default, acr values match", + ` +{ + "acr": "foo", + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo", + "", + require.NoError, + }, + { + "1 - default, acr values do not match", + ` +{ + "acr": "foo", + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "bar", + "", + require.Error, + }, + { + "2 - netiq, acr values match", + ` +{ + "acr": { + "values": [ + "foo/bar/baz" + ] + }, + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo/bar/baz", + "netiq", + require.NoError, + }, + { + "3 - netiq, invalid format", + ` +{ + "acr": { + "values": "foo/bar/baz" + }, + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo/bar/baz", + "netiq", + require.Error, + }, + { + "4 - netiq, invalid value", + ` +{ + "acr": { + "values": [ + "foo/bar/baz/qux" + ] + }, + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo/bar/baz", + "netiq", + require.Error, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.comment, func(t *testing.T) { + t.Parallel() + var claims jose.Claims + err := json.Unmarshal([]byte(tt.inIDToken), &claims) + require.NoError(t, err) + + err = validateACRValues(tt.inACRValue, tt.inACRProvider, claims) + tt.outIsValid(t, err) + }) + } +} diff --git a/lib/auth/saml.go b/lib/auth/saml.go index b4aa374aefe09..a8bfbf2b3f4b7 100644 --- a/lib/auth/saml.go +++ b/lib/auth/saml.go @@ -21,8 +21,10 @@ import ( "compress/flate" "context" "encoding/base64" + "encoding/json" "fmt" "io" + "sync" "github.com/beevik/etree" "github.com/google/go-cmp/cmp" @@ -38,21 +40,37 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/events" "github.com/gravitational/teleport/lib/services" + "github.com/gravitational/teleport/lib/services/local" "github.com/gravitational/teleport/lib/utils" ) -// ErrSAMLNoRoles results from not mapping any roles from SAML claims. -var ErrSAMLNoRoles = trace.AccessDenied("No roles mapped from claims. The mappings may contain typos.") +// ErrSAMLRequiresEnterprise is the error returned by the SAML methods when not +// using the Enterprise edition of Teleport. +var ErrSAMLRequiresEnterprise = fmt.Errorf("SAML: %w", ErrRequiresEnterprise) + +// SAMLService are the methods that the auth server delegates to a plugin for +// implementing the SAML connector. These are the core functions of SAML +// authentication - the connector CRUD operations and Get methods are +// implemeneted in auth.Server and provide no connector-specific logic. +type SAMLService interface { + // CreateSAMLAuthRequest creates SAML AuthnRequest + CreateSAMLAuthRequest(ctx context.Context, req types.SAMLAuthRequest) (*types.SAMLAuthRequest, error) + // ValidateSAMLResponse validates SAML auth response + ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) +} // UpsertSAMLConnector creates or updates a SAML connector. func (a *Server) UpsertSAMLConnector(ctx context.Context, connector types.SAMLConnector) error { + // Validate the SAML connector here, because even though Services.UpsertSAMLConnector + // also validates, it does not have a RoleGetter to use to validate the roles, so + // has to pass `nil` for the second argument. if err := services.ValidateSAMLConnector(connector, a); err != nil { return trace.Wrap(err) } if err := a.Services.UpsertSAMLConnector(ctx, connector); err != nil { return trace.Wrap(err) } - if err := a.emitter.EmitAuditEvent(ctx, &apievents.OIDCConnectorCreate{ + if err := a.emitter.EmitAuditEvent(ctx, &apievents.SAMLConnectorCreate{ Metadata: apievents.Metadata{ Type: events.SAMLConnectorCreatedEvent, Code: events.SAMLConnectorCreatedCode, @@ -68,19 +86,19 @@ func (a *Server) UpsertSAMLConnector(ctx context.Context, connector types.SAMLCo return nil } -// DeleteSAMLConnector deletes a SAML connector by name. -func (a *Server) DeleteSAMLConnector(ctx context.Context, connectorName string) error { - if err := a.Services.DeleteSAMLConnector(ctx, connectorName); err != nil { +// DeleteSAMLConnector deletes a SAML connector. +func (a *Server) DeleteSAMLConnector(ctx context.Context, connectorID string) error { + if err := a.Services.DeleteSAMLConnector(ctx, connectorID); err != nil { return trace.Wrap(err) } - if err := a.emitter.EmitAuditEvent(ctx, &apievents.OIDCConnectorDelete{ + if err := a.emitter.EmitAuditEvent(ctx, &apievents.SAMLConnectorDelete{ Metadata: apievents.Metadata{ Type: events.SAMLConnectorDeletedEvent, Code: events.SAMLConnectorDeletedCode, }, UserMetadata: ClientUserMetadata(ctx), ResourceMetadata: apievents.ResourceMetadata{ - Name: connectorName, + Name: connectorID, }, }); err != nil { log.WithError(err).Warn("Failed to emit SAML connector delete event.") @@ -89,8 +107,69 @@ func (a *Server) DeleteSAMLConnector(ctx context.Context, connectorName string) return nil } +// CreateSAMLAuthRequest delegates the method call to the samlAuthService if present, +// or returns a NotImplemented error if not present. func (a *Server) CreateSAMLAuthRequest(ctx context.Context, req types.SAMLAuthRequest) (*types.SAMLAuthRequest, error) { - connector, provider, err := a.getSAMLConnectorAndProvider(ctx, req) + if a.samlAuthService == nil { + return nil, trace.Wrap(ErrSAMLRequiresEnterprise) + } + + rq, err := a.samlAuthService.CreateSAMLAuthRequest(ctx, req) + return rq, trace.Wrap(err) +} + +// ValidateSAMLResponse delegates the method call to the samlAuthService if present, +// or returns a NotImplemented error if not present. +func (a *Server) ValidateSAMLResponse(ctx context.Context, re string, connectorID string) (*SAMLAuthResponse, error) { + if a.samlAuthService == nil { + return nil, trace.Wrap(ErrSAMLRequiresEnterprise) + } + + resp, err := a.samlAuthService.ValidateSAMLResponse(ctx, re, connectorID) + return resp, trace.Wrap(err) +} + +// SAMLAuthService implements the logic of the SAML connector, allowing SSO +// logins using the SAML protocol. +// +// SAMLAuthService implements the SAMLService interface. +type SAMLAuthService struct { + auth *Server + emitter apievents.Emitter + assertionReplayService *local.AssertionReplayService + samlProviders map[string]*samlProvider + lock sync.Mutex +} + +type SAMLAuthServiceConfig struct { + Auth *Server + Emitter apievents.Emitter + AssertionReplayService *local.AssertionReplayService +} + +// NewSAMLAuthService returns a SAMLAuthService configured to use the +// services given in the config. +func NewSAMLAuthService(cfg *SAMLAuthServiceConfig) *SAMLAuthService { + return &SAMLAuthService{ + auth: cfg.Auth, + emitter: cfg.Emitter, + assertionReplayService: cfg.AssertionReplayService, + + samlProviders: make(map[string]*samlProvider), + } +} + +// samlProvider is internal structure that stores SAML client and its config +type samlProvider struct { + provider *saml2.SAMLServiceProvider + connector types.SAMLConnector +} + +// ErrSAMLNoRoles results from not mapping any roles from SAML claims. +var ErrSAMLNoRoles = trace.AccessDenied("No roles mapped from claims. The mappings may contain typos.") + +func (sas *SAMLAuthService) CreateSAMLAuthRequest(ctx context.Context, req types.SAMLAuthRequest) (*types.SAMLAuthRequest, error) { + connector, provider, err := sas.getSAMLConnectorAndProvider(ctx, req) if err != nil { return nil, trace.Wrap(err) } @@ -122,19 +201,19 @@ func (a *Server) CreateSAMLAuthRequest(ctx context.Context, req types.SAMLAuthRe return nil, trace.Wrap(err) } - err = a.Services.CreateSAMLAuthRequest(ctx, req, defaults.SAMLAuthRequestTTL) + err = sas.auth.Services.CreateSAMLAuthRequest(ctx, req, defaults.SAMLAuthRequestTTL) if err != nil { return nil, trace.Wrap(err) } return &req, nil } -func (a *Server) getSAMLConnectorAndProviderByID(ctx context.Context, connectorID string) (types.SAMLConnector, *saml2.SAMLServiceProvider, error) { - connector, err := a.Identity.GetSAMLConnector(ctx, connectorID, true) +func (sas *SAMLAuthService) getSAMLConnectorAndProviderByID(ctx context.Context, connectorID string) (types.SAMLConnector, *saml2.SAMLServiceProvider, error) { + connector, err := sas.auth.Identity.GetSAMLConnector(ctx, connectorID, true) if err != nil { return nil, nil, trace.Wrap(err) } - provider, err := a.getSAMLProvider(connector) + provider, err := sas.getSAMLProvider(connector) if err != nil { return nil, nil, trace.Wrap(err) } @@ -142,7 +221,7 @@ func (a *Server) getSAMLConnectorAndProviderByID(ctx context.Context, connectorI return connector, provider, nil } -func (a *Server) getSAMLConnectorAndProvider(ctx context.Context, req types.SAMLAuthRequest) (types.SAMLConnector, *saml2.SAMLServiceProvider, error) { +func (sas *SAMLAuthService) getSAMLConnectorAndProvider(ctx context.Context, req types.SAMLAuthRequest) (types.SAMLConnector, *saml2.SAMLServiceProvider, error) { if req.SSOTestFlow { if req.ConnectorSpec == nil { return nil, nil, trace.BadParameter("ConnectorSpec cannot be nil when SSOTestFlow is true") @@ -159,13 +238,13 @@ func (a *Server) getSAMLConnectorAndProvider(ctx context.Context, req types.SAML } // validate, set defaults for connector - err = services.ValidateSAMLConnector(connector, a) + err = services.ValidateSAMLConnector(connector, sas.auth) if err != nil { return nil, nil, trace.Wrap(err) } - // we don't want to cache the provider. construct it directly instead of using a.getSAMLProvider() - provider, err := services.GetSAMLServiceProvider(connector, a.clock) + // we don't want to cache the provider. construct it directly instead of using sas.getSAMLProvider() + provider, err := services.GetSAMLServiceProvider(connector, sas.auth.GetClock()) if err != nil { return nil, nil, trace.Wrap(err) } @@ -174,52 +253,52 @@ func (a *Server) getSAMLConnectorAndProvider(ctx context.Context, req types.SAML } // regular execution flow - return a.getSAMLConnectorAndProviderByID(ctx, req.ConnectorID) + return sas.getSAMLConnectorAndProviderByID(ctx, req.ConnectorID) } -func (a *Server) getSAMLProvider(conn types.SAMLConnector) (*saml2.SAMLServiceProvider, error) { - a.lock.Lock() - defer a.lock.Unlock() +func (sas *SAMLAuthService) getSAMLProvider(conn types.SAMLConnector) (*saml2.SAMLServiceProvider, error) { + sas.lock.Lock() + defer sas.lock.Unlock() - providerPack, ok := a.samlProviders[conn.GetName()] + providerPack, ok := sas.samlProviders[conn.GetName()] if ok && cmp.Equal(providerPack.connector, conn) { return providerPack.provider, nil } - delete(a.samlProviders, conn.GetName()) + delete(sas.samlProviders, conn.GetName()) - serviceProvider, err := services.GetSAMLServiceProvider(conn, a.clock) + serviceProvider, err := services.GetSAMLServiceProvider(conn, sas.auth.GetClock()) if err != nil { return nil, trace.Wrap(err) } - a.samlProviders[conn.GetName()] = &samlProvider{connector: conn, provider: serviceProvider} + sas.samlProviders[conn.GetName()] = &samlProvider{connector: conn, provider: serviceProvider} return serviceProvider, nil } -func (a *Server) calculateSAMLUser(diagCtx *ssoDiagContext, connector types.SAMLConnector, assertionInfo saml2.AssertionInfo, request *types.SAMLAuthRequest) (*createUserParams, error) { - p := createUserParams{ - connectorName: connector.GetName(), - username: assertionInfo.NameID, +func (sas *SAMLAuthService) calculateSAMLUser(diagCtx *SSODiagContext, connector types.SAMLConnector, assertionInfo saml2.AssertionInfo, request *types.SAMLAuthRequest) (*CreateUserParams, error) { + p := CreateUserParams{ + ConnectorName: connector.GetName(), + Username: assertionInfo.NameID, } - p.traits = services.SAMLAssertionsToTraits(assertionInfo) + p.Traits = services.SAMLAssertionsToTraits(assertionInfo) - diagCtx.info.SAMLTraitsFromAssertions = p.traits - diagCtx.info.SAMLConnectorTraitMapping = connector.GetTraitMappings() + diagCtx.Info.SAMLTraitsFromAssertions = p.Traits + diagCtx.Info.SAMLConnectorTraitMapping = connector.GetTraitMappings() var warnings []string - warnings, p.roles = services.TraitsToRoles(connector.GetTraitMappings(), p.traits) - if len(p.roles) == 0 { + warnings, p.Roles = services.TraitsToRoles(connector.GetTraitMappings(), p.Traits) + if len(p.Roles) == 0 { if len(warnings) != 0 { log.WithField("connector", connector).Warnf("No roles mapped from claims. Warnings: %q", warnings) - diagCtx.info.SAMLAttributesToRolesWarnings = &types.SSOWarnings{ + diagCtx.Info.SAMLAttributesToRolesWarnings = &types.SSOWarnings{ Message: "No roles mapped for the user", Warnings: warnings, } } else { log.WithField("connector", connector).Warnf("No roles mapped from claims.") - diagCtx.info.SAMLAttributesToRolesWarnings = &types.SSOWarnings{ + diagCtx.Info.SAMLAttributesToRolesWarnings = &types.SSOWarnings{ Message: "No roles mapped for the user. The mappings may contain typos.", } } @@ -227,52 +306,52 @@ func (a *Server) calculateSAMLUser(diagCtx *ssoDiagContext, connector types.SAML } // Pick smaller for role: session TTL from role or requested TTL. - roles, err := services.FetchRoles(p.roles, a, p.traits) + roles, err := services.FetchRoles(p.Roles, sas.auth, p.Traits) if err != nil { return nil, trace.Wrap(err) } roleTTL := roles.AdjustSessionTTL(apidefaults.MaxCertDuration) if request != nil { - p.sessionTTL = utils.MinTTL(roleTTL, request.CertTTL) + p.SessionTTL = utils.MinTTL(roleTTL, request.CertTTL) } else { - p.sessionTTL = roleTTL + p.SessionTTL = roleTTL } return &p, nil } -func (a *Server) createSAMLUser(p *createUserParams, dryRun bool) (types.User, error) { - expires := a.GetClock().Now().UTC().Add(p.sessionTTL) +func (sas *SAMLAuthService) createSAMLUser(p *CreateUserParams, dryRun bool) (types.User, error) { + expires := sas.auth.GetClock().Now().UTC().Add(p.SessionTTL) - log.Debugf("Generating dynamic SAML identity %v/%v with roles: %v. Dry run: %v.", p.connectorName, p.username, p.roles, dryRun) + log.Debugf("Generating dynamic SAML identity %v/%v with roles: %v. Dry run: %v.", p.ConnectorName, p.Username, p.Roles, dryRun) user := &types.UserV2{ Kind: types.KindUser, Version: types.V2, Metadata: types.Metadata{ - Name: p.username, + Name: p.Username, Namespace: apidefaults.Namespace, Expires: &expires, }, Spec: types.UserSpecV2{ - Roles: p.roles, - Traits: p.traits, + Roles: p.Roles, + Traits: p.Traits, SAMLIdentities: []types.ExternalIdentity{ { - ConnectorID: p.connectorName, - Username: p.username, + ConnectorID: p.ConnectorName, + Username: p.Username, }, }, CreatedBy: types.CreatedBy{ User: types.UserRef{ Name: teleport.UserSystem, }, - Time: a.clock.Now().UTC(), + Time: sas.auth.GetClock().Now().UTC(), Connector: &types.ConnectorRef{ Type: constants.SAML, - ID: p.connectorName, - Identity: p.username, + ID: p.ConnectorName, + Identity: p.Username, }, }, }, @@ -283,7 +362,7 @@ func (a *Server) createSAMLUser(p *createUserParams, dryRun bool) (types.User, e } // Get the user to check if it already exists or not. - existingUser, err := a.Services.GetUser(p.username, false) + existingUser, err := sas.auth.Services.GetUser(p.Username, false) if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) } @@ -303,11 +382,11 @@ func (a *Server) createSAMLUser(p *createUserParams, dryRun bool) (types.User, e log.Debugf("Overwriting existing user %q created with %v connector %v.", existingUser.GetName(), connectorRef.Type, connectorRef.ID) - if err := a.UpdateUser(ctx, user); err != nil { + if err := sas.auth.UpdateUser(ctx, user); err != nil { return nil, trace.Wrap(err) } } else { - if err := a.CreateUser(ctx, user); err != nil { + if err := sas.auth.CreateUser(ctx, user); err != nil { return nil, trace.Wrap(err) } } @@ -389,6 +468,33 @@ type SAMLAuthRequest struct { ClientRedirectURL string `json:"client_redirect_url"` } +// ValidateSAMLResponseReq is the request made by the proxy to validate +// and activate a login via SAML. +type ValidateSAMLResponseReq struct { + Response string `json:"response"` + ConnectorID string `json:"connector_id,omitempty"` +} + +// SAMLAuthRawResponse is returned when auth server validated callback parameters +// returned from SAML provider +type SAMLAuthRawResponse struct { + // Username is authenticated teleport username + Username string `json:"username"` + // Identity contains validated OIDC identity + Identity types.ExternalIdentity `json:"identity"` + // Web session will be generated by auth server if requested in OIDCAuthRequest + Session json.RawMessage `json:"session,omitempty"` + // Cert will be generated by certificate authority + Cert []byte `json:"cert,omitempty"` + // Req is original oidc auth request + Req SAMLAuthRequest `json:"req"` + // HostSigners is a list of signing host public keys + // trusted by proxy, used in console login + HostSigners []json.RawMessage `json:"host_signers"` + // TLSCert is TLS certificate authority certificate + TLSCert []byte `json:"tls_cert,omitempty"` +} + // SAMLAuthRequestFromProto converts the types.SAMLAuthRequest to SAMLAuthRequestData. func SAMLAuthRequestFromProto(req *types.SAMLAuthRequest) SAMLAuthRequest { return SAMLAuthRequest{ @@ -401,7 +507,7 @@ func SAMLAuthRequestFromProto(req *types.SAMLAuthRequest) SAMLAuthRequest { } // ValidateSAMLResponse consumes attribute statements from SAML identity provider -func (a *Server) ValidateSAMLResponse(ctx context.Context, samlResponse string, connectorID string) (*SAMLAuthResponse, error) { +func (sas *SAMLAuthService) ValidateSAMLResponse(ctx context.Context, samlResponse string, connectorID string) (*SAMLAuthResponse, error) { event := &apievents.UserLogin{ Metadata: apievents.Metadata{ Type: events.UserLoginEvent, @@ -409,14 +515,14 @@ func (a *Server) ValidateSAMLResponse(ctx context.Context, samlResponse string, Method: events.LoginMethodSAML, } - diagCtx := a.newSSODiagContext(types.KindSAML) + diagCtx := NewSSODiagContext(types.KindSAML, sas.auth) - auth, err := a.validateSAMLResponse(ctx, diagCtx, samlResponse, connectorID) - diagCtx.info.Error = trace.UserMessage(err) + auth, err := sas.validateSAMLResponse(ctx, diagCtx, samlResponse, connectorID) + diagCtx.Info.Error = trace.UserMessage(err) - diagCtx.writeToBackend(ctx) + diagCtx.WriteToBackend(ctx) - attributeStatements := diagCtx.info.SAMLAttributeStatements + attributeStatements := diagCtx.Info.SAMLAttributeStatements if attributeStatements != nil { attributes, err := apievents.EncodeMapStrings(attributeStatements) if err != nil { @@ -429,13 +535,13 @@ func (a *Server) ValidateSAMLResponse(ctx context.Context, samlResponse string, if err != nil { event.Code = events.UserSSOLoginFailureCode - if diagCtx.info.TestFlow { + if diagCtx.Info.TestFlow { event.Code = events.UserSSOTestFlowLoginFailureCode } event.Status.Success = false event.Status.Error = trace.Unwrap(err).Error() event.Status.UserMessage = err.Error() - if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil { + if err := sas.emitter.EmitAuditEvent(ctx, event); err != nil { log.WithError(err).Warn("Failed to emit SAML login failed event.") } return nil, trace.Wrap(err) @@ -444,18 +550,18 @@ func (a *Server) ValidateSAMLResponse(ctx context.Context, samlResponse string, event.Status.Success = true event.User = auth.Username event.Code = events.UserSSOLoginCode - if diagCtx.info.TestFlow { + if diagCtx.Info.TestFlow { event.Code = events.UserSSOTestFlowLoginCode } - if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil { + if err := sas.emitter.EmitAuditEvent(ctx, event); err != nil { log.WithError(err).Warn("Failed to emit SAML login event.") } return auth, nil } -func (a *Server) checkIDPInitiatedSAML(ctx context.Context, connector types.SAMLConnector, assertion *saml2.AssertionInfo) error { +func (sas *SAMLAuthService) checkIDPInitiatedSAML(ctx context.Context, connector types.SAMLConnector, assertion *saml2.AssertionInfo) error { if !connector.GetAllowIDPInitiated() { return trace.AccessDenied("IdP initiated SAML is not allowed by the connector configuration") } @@ -465,11 +571,11 @@ func (a *Server) checkIDPInitiatedSAML(ctx context.Context, connector types.SAML return nil } - err := a.unstable.RecognizeSSOAssertion(ctx, connector.GetName(), assertion.SessionIndex, assertion.NameID, *assertion.SessionNotOnOrAfter) + err := sas.assertionReplayService.RecognizeSSOAssertion(ctx, connector.GetName(), assertion.SessionIndex, assertion.NameID, *assertion.SessionNotOnOrAfter) return trace.Wrap(err) } -func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagContext, samlResponse string, connectorID string) (*SAMLAuthResponse, error) { +func (sas *SAMLAuthService) validateSAMLResponse(ctx context.Context, diagCtx *SSODiagContext, samlResponse string, connectorID string) (*SAMLAuthResponse, error) { idpInitiated := false var connector types.SAMLConnector var provider *saml2.SAMLServiceProvider @@ -482,21 +588,21 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte } idpInitiated = true - connector, provider, err = a.getSAMLConnectorAndProviderByID(ctx, connectorID) + connector, provider, err = sas.getSAMLConnectorAndProviderByID(ctx, connectorID) if err != nil { return nil, trace.Wrap(err, "Failed to get SAML connector and provider") } case err != nil: return nil, trace.Wrap(err) default: - diagCtx.requestID = requestID - request, err = a.Identity.GetSAMLAuthRequest(ctx, requestID) + diagCtx.RequestID = requestID + request, err = sas.auth.Identity.GetSAMLAuthRequest(ctx, requestID) if err != nil { return nil, trace.Wrap(err, "Failed to get SAML Auth Request") } - diagCtx.info.TestFlow = request.SSOTestFlow - connector, provider, err = a.getSAMLConnectorAndProvider(ctx, *request) + diagCtx.Info.TestFlow = request.SSOTestFlow + connector, provider, err = sas.getSAMLConnectorAndProvider(ctx, *request) if err != nil { return nil, trace.Wrap(err, "Failed to get SAML connector and provider") } @@ -509,11 +615,11 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte } if assertionInfo != nil { - diagCtx.info.SAMLAssertionInfo = (*types.AssertionInfo)(assertionInfo) + diagCtx.Info.SAMLAssertionInfo = (*types.AssertionInfo)(assertionInfo) } if idpInitiated { - if err := a.checkIDPInitiatedSAML(ctx, connector, assertionInfo); err != nil { + if err := sas.checkIDPInitiatedSAML(ctx, connector, assertionInfo); err != nil { if trace.IsAccessDenied(err) { log.Warnf("Failed to process IdP-initiated login request. IdP-initiated login is disabled for this connector: %v.", err) } @@ -546,8 +652,8 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte attributeStatements[key] = vals } - diagCtx.info.SAMLAttributeStatements = attributeStatements - diagCtx.info.SAMLAttributesToRoles = connector.GetAttributesToRoles() + diagCtx.Info.SAMLAttributeStatements = attributeStatements + diagCtx.Info.SAMLAttributesToRoles = connector.GetAttributesToRoles() if len(connector.GetAttributesToRoles()) == 0 { samlErr := trace.BadParameter("no attributes to roles mapping, check connector documentation") @@ -558,22 +664,22 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte // Calculate (figure out name, roles, traits, session TTL) of user and // create the user in the backend. - params, err := a.calculateSAMLUser(diagCtx, connector, *assertionInfo, request) + params, err := sas.calculateSAMLUser(diagCtx, connector, *assertionInfo, request) if err != nil { return nil, trace.Wrap(err, "Failed to calculate user attributes.") } - diagCtx.info.CreateUserParams = &types.CreateUserParams{ - ConnectorName: params.connectorName, - Username: params.username, - KubeGroups: params.kubeGroups, - KubeUsers: params.kubeUsers, - Roles: params.roles, - Traits: params.traits, - SessionTTL: types.Duration(params.sessionTTL), + diagCtx.Info.CreateUserParams = &types.CreateUserParams{ + ConnectorName: params.ConnectorName, + Username: params.Username, + KubeGroups: params.KubeGroups, + KubeUsers: params.KubeUsers, + Roles: params.Roles, + Traits: params.Traits, + SessionTTL: types.Duration(params.SessionTTL), } - user, err := a.createSAMLUser(params, request != nil && request.SSOTestFlow) + user, err := sas.createSAMLUser(params, request != nil && request.SSOTestFlow) if err != nil { return nil, trace.Wrap(err, "Failed to create user from provided parameters.") } @@ -581,8 +687,8 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte // Auth was successful, return session, certificate, etc. to caller. auth := &SAMLAuthResponse{ Identity: types.ExternalIdentity{ - ConnectorID: params.connectorName, - Username: params.username, + ConnectorID: params.ConnectorName, + Username: params.Username, }, Username: user.GetName(), } @@ -597,18 +703,18 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte // In test flow skip signing and creating web sessions. if request != nil && request.SSOTestFlow { - diagCtx.info.Success = true + diagCtx.Info.Success = true return auth, nil } // If the request is coming from a browser, create a web session. if request == nil || request.CreateWebSession { - session, err := a.createWebSession(ctx, types.NewWebSessionRequest{ + session, err := sas.auth.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ User: user.GetName(), Roles: user.GetRoles(), Traits: user.GetTraits(), - SessionTTL: params.sessionTTL, - LoginTime: a.clock.Now().UTC(), + SessionTTL: params.SessionTTL, + LoginTime: sas.auth.GetClock().Now().UTC(), }) if err != nil { return nil, trace.Wrap(err, "Failed to create web session.") @@ -619,12 +725,12 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte // If a public key was provided, sign it and return a certificate. if request != nil && len(request.PublicKey) != 0 { - sshCert, tlsCert, err := a.createSessionCert(user, params.sessionTTL, request.PublicKey, request.Compatibility, request.RouteToCluster, + sshCert, tlsCert, err := sas.auth.CreateSessionCert(user, params.SessionTTL, request.PublicKey, request.Compatibility, request.RouteToCluster, request.KubernetesCluster, keys.AttestationStatementFromProto(request.AttestationStatement)) if err != nil { return nil, trace.Wrap(err, "Failed to create session certificate.") } - clusterName, err := a.GetClusterName() + clusterName, err := sas.auth.GetClusterName() if err != nil { return nil, trace.Wrap(err, "Failed to obtain cluster name.") } @@ -632,7 +738,7 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte auth.TLSCert = tlsCert // Return the host CA for this cluster only. - authority, err := a.GetCertAuthority(ctx, types.CertAuthID{ + authority, err := sas.auth.GetCertAuthority(ctx, types.CertAuthID{ Type: types.HostCA, DomainName: clusterName.GetClusterName(), }, false) @@ -642,6 +748,6 @@ func (a *Server) validateSAMLResponse(ctx context.Context, diagCtx *ssoDiagConte auth.HostSigners = append(auth.HostSigners, authority) } - diagCtx.info.Success = true + diagCtx.Info.Success = true return auth, nil } diff --git a/lib/auth/saml_test.go b/lib/auth/saml_test.go index e62600469ebdc..6f384de385fb4 100644 --- a/lib/auth/saml_test.go +++ b/lib/auth/saml_test.go @@ -76,12 +76,15 @@ func TestCreateSAMLUser(t *testing.T) { a, err := NewServer(authConfig) require.NoError(t, err) + sas, ok := a.samlAuthService.(*SAMLAuthService) + require.True(t, ok, "Server.samlAuthServer is not type *samlAuthServer") + // Dry-run creation of SAML user. - user, err := a.createSAMLUser(&createUserParams{ - connectorName: "samlService", - username: "foo@example.com", - roles: []string{"admin"}, - sessionTTL: 1 * time.Minute, + user, err := sas.createSAMLUser(&CreateUserParams{ + ConnectorName: "samlService", + Username: "foo@example.com", + Roles: []string{"admin"}, + SessionTTL: 1 * time.Minute, }, true) require.NoError(t, err) require.Equal(t, "foo@example.com", user.GetName()) @@ -91,11 +94,11 @@ func TestCreateSAMLUser(t *testing.T) { require.Error(t, err) // Create SAML user with 1 minute expiry. - _, err = a.createSAMLUser(&createUserParams{ - connectorName: "samlService", - username: "foo@example.com", - roles: []string{"admin"}, - sessionTTL: 1 * time.Minute, + _, err = sas.createSAMLUser(&CreateUserParams{ + ConnectorName: "samlService", + Username: "foo@example.com", + Roles: []string{"admin"}, + SessionTTL: 1 * time.Minute, }, false) require.NoError(t, err) @@ -304,6 +307,9 @@ func TestServer_getConnectorAndProvider(t *testing.T) { a, err := NewServer(authConfig) require.NoError(t, err) + sas, ok := a.samlAuthService.(*SAMLAuthService) + require.True(t, ok, "Server.samlAuthServer is not type *samlAuthServer") + _, err = CreateRole(ctx, a, "baz", types.RoleSpecV5{}) require.NoError(t, err) @@ -352,7 +358,7 @@ func TestServer_getConnectorAndProvider(t *testing.T) { }, } - connector, provider, err := a.getSAMLConnectorAndProvider(context.Background(), request) + connector, provider, err := sas.getSAMLConnectorAndProvider(context.Background(), request) require.NoError(t, err) require.NotNil(t, connector) require.NotNil(t, provider) @@ -387,7 +393,7 @@ func TestServer_getConnectorAndProvider(t *testing.T) { SSOTestFlow: false, } - connector, provider, err = a.getSAMLConnectorAndProvider(context.Background(), request2) + connector, provider, err = sas.getSAMLConnectorAndProvider(context.Background(), request2) require.NoError(t, err) require.NotNil(t, connector) require.NotNil(t, provider) @@ -423,10 +429,11 @@ func TestServer_ValidateSAMLResponse(t *testing.T) { }, } - a, err := NewServer(authConfig) + a, err := NewServer(authConfig, WithClock(clock)) require.NoError(t, err) - a.SetClock(clock) + sas, ok := a.samlAuthService.(*SAMLAuthService) + require.True(t, ok, "Server.samlAuthServer is not type *samlAuthServer") // empty response gives error. response, err := a.ValidateSAMLResponse(context.Background(), "", "") @@ -556,19 +563,19 @@ V115UGOwvjOOxmOFbYBn865SHgMndFtr require.NoError(t, err) // check ValidateSAMLResponse - response, err = a.ValidateSAMLResponse(context.Background(), base64.StdEncoding.EncodeToString([]byte(respOkta)), "") + response, err = sas.ValidateSAMLResponse(context.Background(), base64.StdEncoding.EncodeToString([]byte(respOkta)), "") require.NoError(t, err) require.NotNil(t, response) // check internal method, validate diagnostic outputs. - diagCtx := a.newSSODiagContext(types.KindSAML) - auth, err := a.validateSAMLResponse(context.Background(), diagCtx, base64.StdEncoding.EncodeToString([]byte(respOkta)), "") + diagCtx := NewSSODiagContext(types.KindSAML, a) + auth, err := sas.validateSAMLResponse(context.Background(), diagCtx, base64.StdEncoding.EncodeToString([]byte(respOkta)), "") require.NoError(t, err) // ensure diag info got stored and is identical. infoFromBackend, err := a.GetSSODiagnosticInfo(context.Background(), types.KindSAML, auth.Req.ID) require.NoError(t, err) - require.Equal(t, &diagCtx.info, infoFromBackend) + require.Equal(t, &diagCtx.Info, infoFromBackend) // verify values require.Equal(t, "ops@gravitational.io", auth.Username) @@ -580,8 +587,8 @@ V115UGOwvjOOxmOFbYBn865SHgMndFtr authnInstant := time.Date(2022, 4, 25, 8, 3, 11, 779000000, time.UTC) // ignore, this is boring and very complex. - require.NotNil(t, diagCtx.info.SAMLAssertionInfo.Assertions) - diagCtx.info.SAMLAssertionInfo.Assertions = nil + require.NotNil(t, diagCtx.Info.SAMLAssertionInfo.Assertions) + diagCtx.Info.SAMLAssertionInfo.Assertions = nil require.Equal(t, types.SSODiagnosticInfo{ TestFlow: true, @@ -661,7 +668,7 @@ V115UGOwvjOOxmOFbYBn865SHgMndFtr Roles: []string{"access"}, }, }, - }, diagCtx.info) + }, diagCtx.Info) // make sure no users have been created. users, err := a.GetUsers(false) diff --git a/lib/auth/sessions.go b/lib/auth/sessions.go index c10325859dc53..ebf9c62a78307 100644 --- a/lib/auth/sessions.go +++ b/lib/auth/sessions.go @@ -240,7 +240,7 @@ func (s *Server) generateAppToken(ctx context.Context, username string, roles [] return token, nil } -func (s *Server) createWebSession(ctx context.Context, req types.NewWebSessionRequest) (types.WebSession, error) { +func (s *Server) CreateWebSessionFromReq(ctx context.Context, req types.NewWebSessionRequest) (types.WebSession, error) { session, err := s.NewWebSession(ctx, req) if err != nil { return nil, trace.Wrap(err) @@ -254,7 +254,7 @@ func (s *Server) createWebSession(ctx context.Context, req types.NewWebSessionRe return session, nil } -func (s *Server) createSessionCert(user types.User, sessionTTL time.Duration, publicKey []byte, compatibility, routeToCluster, kubernetesCluster string, attestationReq *keys.AttestationStatement) ([]byte, []byte, error) { +func (s *Server) CreateSessionCert(user types.User, sessionTTL time.Duration, publicKey []byte, compatibility, routeToCluster, kubernetesCluster string, attestationReq *keys.AttestationStatement) ([]byte, []byte, error) { // It's safe to extract the access info directly from services.User because // this occurs during the initial login before the first certs have been // generated, so there's no possibility of any active access requests. diff --git a/lib/auth/sso_diag_context.go b/lib/auth/sso_diag_context.go index 6f271cf1c8685..ad939c056618a 100644 --- a/lib/auth/sso_diag_context.go +++ b/lib/auth/sso_diag_context.go @@ -22,33 +22,48 @@ import ( "github.com/gravitational/teleport/api/types" ) -// ssoDiagContext is a helper type for accumulating the SSO diagnostic info prior to writing it to the backend. -type ssoDiagContext struct { - // authKind is auth kind such as types.KindSAML - authKind string - // createSSODiagnosticInfo is a callback to create the types.SSODiagnosticInfo record in the backend. - createSSODiagnosticInfo func(ctx context.Context, authKind string, authRequestID string, info types.SSODiagnosticInfo) error - // requestID is the ID of the auth request being processed. - requestID string - // info accumulates SSO diagnostic info - info types.SSODiagnosticInfo +// SSODiagContext is a helper type for accumulating the SSO diagnostic info prior to writing it to the backend. +type SSODiagContext struct { + // AuthKind is auth kind such as types.KindSAML + AuthKind string + // DiagService is the SSODiagService that will record our diagnostic info in the backend. + DiagService SSODiagService + // RequestID is the ID of the auth request being processed. + RequestID string + // Info accumulates SSO diagnostic Info + Info types.SSODiagnosticInfo } -// writeToBackend saves the accumulated SSO diagnostic information to the backend. -func (c *ssoDiagContext) writeToBackend(ctx context.Context) { - if c.info.TestFlow { - err := c.createSSODiagnosticInfo(ctx, c.authKind, c.requestID, c.info) +// SSODiagService is a thin slice of services.Identity required by SSODiagContext +// to record the SSO diagnostic info in a store. +type SSODiagService interface { + // CreateSSODiagnosticInfo creates new SSO diagnostic info record. + CreateSSODiagnosticInfo(ctx context.Context, authKind string, authRequestID string, entry types.SSODiagnosticInfo) error +} + +// SSODiagServiceFunc is an adaptor allowing a function to be used in place +// of the SSODiagService interface. +type SSODiagServiceFunc func(ctx context.Context, authKind string, authRequestID string, entry types.SSODiagnosticInfo) error + +func (f SSODiagServiceFunc) CreateSSODiagnosticInfo(ctx context.Context, authKind string, authRequestID string, entry types.SSODiagnosticInfo) error { + return f(ctx, authKind, authRequestID, entry) +} + +// WriteToBackend saves the accumulated SSO diagnostic information to the backend. +func (c *SSODiagContext) WriteToBackend(ctx context.Context) { + if c.Info.TestFlow { + err := c.DiagService.CreateSSODiagnosticInfo(ctx, c.AuthKind, c.RequestID, c.Info) if err != nil { - log.WithError(err).WithField("requestID", c.requestID).Warn("failed to write SSO diag info data") + log.WithError(err).WithField("requestID", c.RequestID).Warn("failed to write SSO diag info data") } } } -// newSSODiagContext returns new ssoDiagContext referencing particular Server. +// NewSSODiagContext returns new ssoDiagContext referencing particular Server. // authKind must be one of supported auth kinds (e.g. types.KindSAML). -func (a *Server) newSSODiagContext(authKind string) *ssoDiagContext { - return &ssoDiagContext{ - authKind: authKind, - createSSODiagnosticInfo: a.CreateSSODiagnosticInfo, +func NewSSODiagContext(authKind string, diagSvc SSODiagService) *SSODiagContext { + return &SSODiagContext{ + AuthKind: authKind, + DiagService: diagSvc, } } diff --git a/lib/auth/sso_diag_context_test.go b/lib/auth/sso_diag_context_test.go index e29d5d5db767a..49abcc9e3c9f4 100644 --- a/lib/auth/sso_diag_context_test.go +++ b/lib/auth/sso_diag_context_test.go @@ -26,30 +26,31 @@ import ( ) func Test_ssoDiagContext_writeToBackend(t *testing.T) { - diag := &ssoDiagContext{ - authKind: types.KindSAML, - requestID: "123", - info: types.SSODiagnosticInfo{}, + diag := &SSODiagContext{ + AuthKind: types.KindSAML, + RequestID: "123", + Info: types.SSODiagnosticInfo{}, } callCount := 0 - diag.createSSODiagnosticInfo = func(ctx context.Context, authKind string, authRequestID string, info types.SSODiagnosticInfo) error { + diagFn := func(ctx context.Context, authKind string, authRequestID string, info types.SSODiagnosticInfo) error { callCount++ - require.Truef(t, info.TestFlow, "createSSODiagnosticInfo must not be called if info.TestFlow is false.") - require.Equal(t, diag.authKind, authKind) - require.Equal(t, diag.requestID, authRequestID) - require.Equal(t, diag.info, info) + require.Truef(t, info.TestFlow, "CreateSSODiagnosticInfo must not be called if info.TestFlow is false.") + require.Equal(t, diag.AuthKind, authKind) + require.Equal(t, diag.RequestID, authRequestID) + require.Equal(t, diag.Info, info) return nil } + diag.DiagService = SSODiagServiceFunc(diagFn) // with TestFlow: false, no call is made. - diag.info.TestFlow = false - diag.writeToBackend(context.Background()) + diag.Info.TestFlow = false + diag.WriteToBackend(context.Background()) require.Equal(t, 0, callCount) // with TestFlow: true, a call is made. - diag.info.TestFlow = true - diag.writeToBackend(context.Background()) + diag.Info.TestFlow = true + diag.WriteToBackend(context.Background()) require.Equal(t, 1, callCount) } diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index d495ce1af3191..d2aafea862fa8 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -85,9 +85,9 @@ import ( ) const ( - // ssoLoginConsoleErr is a generic error message to hide revealing sso login failure msgs. - ssoLoginConsoleErr = "Failed to login. Please check Teleport's log for more details." - metaRedirectHTML = ` + // SSOLoginFailureMessage is a generic error message to avoid disclosing sensitive SSO failure messages. + SSOLoginFailureMessage = "Failed to login. Please check Teleport's log for more details." + metaRedirectHTML = ` @@ -193,10 +193,10 @@ type Config struct { // Enables web UI if set. StaticFS http.FileSystem - // cachedSessionLingeringThreshold specifies the time the session will linger + // CachedSessionLingeringThreshold specifies the time the session will linger // in the cache before getting purged after it has expired. // Defaults to cachedSessionLingeringThreshold if unspecified. - cachedSessionLingeringThreshold *time.Duration + CachedSessionLingeringThreshold *time.Duration // ClusterFeatures contains flags for supported/unsupported features. ClusterFeatures proto.Features @@ -292,8 +292,8 @@ func NewHandler(cfg Config, opts ...HandlerOption) (*APIHandler, error) { } sessionLingeringThreshold := cachedSessionLingeringThreshold - if cfg.cachedSessionLingeringThreshold != nil { - sessionLingeringThreshold = *cfg.cachedSessionLingeringThreshold + if cfg.CachedSessionLingeringThreshold != nil { + sessionLingeringThreshold = *cfg.CachedSessionLingeringThreshold } sessionCache, err := newSessionCache(h.cfg.Context, sessionCacheOptions{ @@ -1333,17 +1333,17 @@ func (h *Handler) oidcLoginWeb(w http.ResponseWriter, r *http.Request, p httprou logger := h.log.WithField("auth", "oidc") logger.Debug("Web login start.") - req, err := parseSSORequestParams(r) + req, err := ParseSSORequestParams(r) if err != nil { logger.WithError(err).Error("Failed to extract SSO parameters from request.") return client.LoginFailedRedirectURL } response, err := h.cfg.ProxyClient.CreateOIDCAuthRequest(r.Context(), types.OIDCAuthRequest{ - CSRFToken: req.csrfToken, - ConnectorID: req.connectorID, + CSRFToken: req.CSRFToken, + ConnectorID: req.ConnectorID, CreateWebSession: true, - ClientRedirectURL: req.clientRedirectURL, + ClientRedirectURL: req.ClientRedirectURL, CheckUser: true, ProxyAddress: r.Host, }) @@ -1359,17 +1359,17 @@ func (h *Handler) githubLoginWeb(w http.ResponseWriter, r *http.Request, p httpr logger := h.log.WithField("auth", "github") logger.Debug("Web login start.") - req, err := parseSSORequestParams(r) + req, err := ParseSSORequestParams(r) if err != nil { logger.WithError(err).Error("Failed to extract SSO parameters from request.") return client.LoginFailedRedirectURL } response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{ - CSRFToken: req.csrfToken, - ConnectorID: req.connectorID, + CSRFToken: req.CSRFToken, + ConnectorID: req.ConnectorID, CreateWebSession: true, - ClientRedirectURL: req.clientRedirectURL, + ClientRedirectURL: req.ClientRedirectURL, }) if err != nil { logger.WithError(err).Error("Error creating auth request.") @@ -1387,12 +1387,12 @@ func (h *Handler) githubLoginConsole(w http.ResponseWriter, r *http.Request, p h req := new(client.SSOLoginConsoleReq) if err := httplib.ReadJSON(r, req); err != nil { logger.WithError(err).Error("Error reading json.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } if err := req.CheckAndSetDefaults(); err != nil { logger.WithError(err).Error("Missing request parameters.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } response, err := h.cfg.ProxyClient.CreateGithubAuthRequest(r.Context(), types.GithubAuthRequest{ @@ -1407,7 +1407,7 @@ func (h *Handler) githubLoginConsole(w http.ResponseWriter, r *http.Request, p h }) if err != nil { logger.WithError(err).Error("Failed to create Github auth request.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } return &client.SSOLoginConsoleResponse{ @@ -1429,7 +1429,7 @@ func (h *Handler) githubCallback(w http.ResponseWriter, r *http.Request, p httpr // this improves the UX by terminating the failed SSO flow immediately, rather than hoping for a timeout. if requestID := r.URL.Query().Get("state"); requestID != "" { if request, errGet := h.cfg.ProxyClient.GetGithubAuthRequest(r.Context(), requestID); errGet == nil && !request.CreateWebSession { - if redURL, errEnc := redirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { + if redURL, errEnc := RedirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { return redURL.String() } } @@ -1445,19 +1445,19 @@ func (h *Handler) githubCallback(w http.ResponseWriter, r *http.Request, p httpr if response.Req.CreateWebSession { logger.Infof("Redirecting to web browser.") - res := &ssoCallbackResponse{ - csrfToken: response.Req.CSRFToken, - username: response.Username, - sessionName: response.Session.GetName(), - clientRedirectURL: response.Req.ClientRedirectURL, + res := &SSOCallbackResponse{ + CSRFToken: response.Req.CSRFToken, + Username: response.Username, + SessionName: response.Session.GetName(), + ClientRedirectURL: response.Req.ClientRedirectURL, } - if err := ssoSetWebSessionAndRedirectURL(w, r, res, true); err != nil { + if err := SSOSetWebSessionAndRedirectURL(w, r, res, true); err != nil { logger.WithError(err).Error("Error setting web session.") return client.LoginFailedRedirectURL } - return res.clientRedirectURL + return res.ClientRedirectURL } logger.Infof("Callback is redirecting to console login.") @@ -1491,12 +1491,12 @@ func (h *Handler) oidcLoginConsole(w http.ResponseWriter, r *http.Request, p htt req := new(client.SSOLoginConsoleReq) if err := httplib.ReadJSON(r, req); err != nil { logger.WithError(err).Error("Error reading json.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } if err := req.CheckAndSetDefaults(); err != nil { logger.WithError(err).Error("Missing request parameters.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } response, err := h.cfg.ProxyClient.CreateOIDCAuthRequest(r.Context(), types.OIDCAuthRequest{ @@ -1513,7 +1513,7 @@ func (h *Handler) oidcLoginConsole(w http.ResponseWriter, r *http.Request, p htt }) if err != nil { logger.WithError(err).Error("Failed to create OIDC auth request.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } return &client.SSOLoginConsoleResponse{ @@ -1535,7 +1535,7 @@ func (h *Handler) oidcCallback(w http.ResponseWriter, r *http.Request, p httprou // this improves the UX by terminating the failed SSO flow immediately, rather than hoping for a timeout. if requestID := r.URL.Query().Get("state"); requestID != "" { if request, errGet := h.cfg.ProxyClient.GetOIDCAuthRequest(r.Context(), requestID); errGet == nil && !request.CreateWebSession { - if redURL, errEnc := redirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { + if redURL, errEnc := RedirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { return redURL.String() } } @@ -1552,19 +1552,19 @@ func (h *Handler) oidcCallback(w http.ResponseWriter, r *http.Request, p httprou if response.Req.CreateWebSession { logger.Info("Redirecting to web browser.") - res := &ssoCallbackResponse{ - csrfToken: response.Req.CSRFToken, - username: response.Username, - sessionName: response.Session.GetName(), - clientRedirectURL: response.Req.ClientRedirectURL, + res := &SSOCallbackResponse{ + CSRFToken: response.Req.CSRFToken, + Username: response.Username, + SessionName: response.Session.GetName(), + ClientRedirectURL: response.Req.ClientRedirectURL, } - if err := ssoSetWebSessionAndRedirectURL(w, r, res, true); err != nil { + if err := SSOSetWebSessionAndRedirectURL(w, r, res, true); err != nil { logger.WithError(err).Error("Error setting web session.") return client.LoginFailedRedirectURL } - return res.clientRedirectURL + return res.ClientRedirectURL } logger.Info("Callback redirecting to console login.") @@ -1715,7 +1715,11 @@ func ConstructSSHResponse(response AuthParams) (*url.URL, error) { return u, nil } -func redirectURLWithError(clientRedirectURL string, errReply error) (*url.URL, error) { +// RedirectURLWithError adds an err query parameter to the given redirect URL with the +// given errReply message and returns the new URL. If the given URL cannot be parsed, +// an error is returned with a nil URL. It is used to return an error back to the +// original URL in an SSO callback when validation fails. +func RedirectURLWithError(clientRedirectURL string, errReply error) (*url.URL, error) { u, err := url.Parse(clientRedirectURL) if err != nil { return nil, trace.Wrap(err) @@ -3578,13 +3582,23 @@ func makeTeleportClientConfig(ctx context.Context, sctx *SessionContext) (*clien return config, nil } -type ssoRequestParams struct { - clientRedirectURL string - connectorID string - csrfToken string +// SSORequestParams holds parameters parsed out of a HTTP request initiating an +// SSO login. See ParseSSORequestParams(). +type SSORequestParams struct { + // ClientRedirectURL is the URL specified in the query parameter + // redirect_url, which will be unescaped here. + ClientRedirectURL string + // ConnectorID identifies the SSO connector to use to log in, from + // the connector_id query parameter. + ConnectorID string + // CSRFToken is the token in the CSRF cookie header. + CSRFToken string } -func parseSSORequestParams(r *http.Request) (*ssoRequestParams, error) { +// ParseSSORequestParams extracts the SSO request parameters from an http.Request, +// returning them in an SSORequestParams struct. If any fields are not present, +// an error is returned. +func ParseSSORequestParams(r *http.Request) (*SSORequestParams, error) { // Manually grab the value from query param "redirect_url". // // The "redirect_url" param can contain its own query params such as in @@ -3616,37 +3630,52 @@ func parseSSORequestParams(r *http.Request) (*ssoRequestParams, error) { return nil, trace.Wrap(err) } - return &ssoRequestParams{ - clientRedirectURL: clientRedirectURL, - connectorID: connectorID, - csrfToken: csrfToken, + return &SSORequestParams{ + ClientRedirectURL: clientRedirectURL, + ConnectorID: connectorID, + CSRFToken: csrfToken, }, nil } -type ssoCallbackResponse struct { - csrfToken string - username string - sessionName string - clientRedirectURL string +// SSOCallbackResponse holds the parameters for validating and executing an SSO +// callback URL. See SSOSetWebSessionAndRedirectURL(). +type SSOCallbackResponse struct { + // CSRFToken is the token provided in the originating SSO login request + // to be validated against. + CSRFToken string + // Username is the authenticated teleport username of the user that has + // logged in, provided by the SSO provider. + Username string + // SessionName is the name of the session generated by auth server if + // requested in the SSO request. + SessionName string + // ClientRedirectURL is the URL to redirect back to on completion of + // the SSO login process. + ClientRedirectURL string } -func ssoSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, response *ssoCallbackResponse, verifyCSRF bool) error { +// SSOSetWebSessionAndRedirectURL validates the CSRF token in the response +// against that in the request, validates that the callback URL in the response +// can be parsed, and sets a session cookie with the username and session name +// from the response. On success, nil is returned. If the validation fails, an +// error is returned. +func SSOSetWebSessionAndRedirectURL(w http.ResponseWriter, r *http.Request, response *SSOCallbackResponse, verifyCSRF bool) error { if verifyCSRF { - if err := csrf.VerifyToken(response.csrfToken, r); err != nil { + if err := csrf.VerifyToken(response.CSRFToken, r); err != nil { return trace.Wrap(err) } } - if err := SetSessionCookie(w, response.username, response.sessionName); err != nil { + if err := SetSessionCookie(w, response.Username, response.SessionName); err != nil { return trace.Wrap(err) } - parsedURL, err := url.Parse(response.clientRedirectURL) + parsedURL, err := url.Parse(response.ClientRedirectURL) if err != nil { return trace.Wrap(err) } - response.clientRedirectURL = parsedURL.RequestURI() + response.ClientRedirectURL = parsedURL.RequestURI() return nil } diff --git a/lib/web/apiserver_test.go b/lib/web/apiserver_test.go index 5892cd9dae79e..49ce5fd4c7f22 100644 --- a/lib/web/apiserver_test.go +++ b/lib/web/apiserver_test.go @@ -437,7 +437,7 @@ func newWebSuiteWithConfig(t *testing.T, cfg webSuiteConfig) *WebSuite { HostUUID: proxyID, Emitter: s.proxyClient, StaticFS: fs, - cachedSessionLingeringThreshold: &sessionLingeringThreshold, + CachedSessionLingeringThreshold: &sessionLingeringThreshold, ProxySettings: &mockProxySettings{}, SessionControl: proxySessionController, Router: router, @@ -4839,33 +4839,33 @@ func TestParseSSORequestParams(t *testing.T) { tests := []struct { name, url string wantErr bool - expected *ssoRequestParams + expected *SSORequestParams }{ { name: "preserve redirect's query params (escaped)", url: "https://localhost/login?connector_id=oidc&redirect_url=https:%2F%2Flocalhost:8080%2Fweb%2Fcluster%2Fim-a-cluster-name%2Fnodes%3Fsearch=tunnel&sort=hostname:asc", - expected: &ssoRequestParams{ - clientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", - connectorID: "oidc", - csrfToken: token, + expected: &SSORequestParams{ + ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", + ConnectorID: "oidc", + CSRFToken: token, }, }, { name: "preserve redirect's query params (unescaped)", url: "https://localhost/login?connector_id=github&redirect_url=https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", - expected: &ssoRequestParams{ - clientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", - connectorID: "github", - csrfToken: token, + expected: &SSORequestParams{ + ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/nodes?search=tunnel&sort=hostname:asc", + ConnectorID: "github", + CSRFToken: token, }, }, { name: "preserve various encoded chars", url: "https://localhost/login?connector_id=saml&redirect_url=https:%2F%2Flocalhost:8080%2Fweb%2Fcluster%2Fim-a-cluster-name%2Fapps%3Fquery=search(%2522watermelon%2522%252C%2520%2522this%2522)%2520%2526%2526%2520labels%255B%2522unique-id%2522%255D%2520%253D%253D%2520%2522hi%2522&sort=name:asc", - expected: &ssoRequestParams{ - clientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc", - connectorID: "saml", - csrfToken: token, + expected: &SSORequestParams{ + ClientRedirectURL: "https://localhost:8080/web/cluster/im-a-cluster-name/apps?query=search(%22watermelon%22%2C%20%22this%22)%20%26%26%20labels%5B%22unique-id%22%5D%20%3D%3D%20%22hi%22&sort=name:asc", + ConnectorID: "saml", + CSRFToken: token, }, }, { @@ -4886,7 +4886,7 @@ func TestParseSSORequestParams(t *testing.T) { require.NoError(t, err) addCSRFCookieToReq(req, token) - params, err := parseSSORequestParams(req) + params, err := ParseSSORequestParams(req) switch { case tc.wantErr: diff --git a/lib/web/saml.go b/lib/web/saml.go index 9d80b3bcd5bb7..c7eba7ee84028 100644 --- a/lib/web/saml.go +++ b/lib/web/saml.go @@ -34,17 +34,17 @@ func (h *Handler) samlSSO(w http.ResponseWriter, r *http.Request, p httprouter.P logger := h.log.WithField("auth", "saml") logger.Debug("Web login start.") - req, err := parseSSORequestParams(r) + req, err := ParseSSORequestParams(r) if err != nil { logger.WithError(err).Error("Failed to extract SSO parameters from request.") return client.LoginFailedRedirectURL } response, err := h.cfg.ProxyClient.CreateSAMLAuthRequest(r.Context(), types.SAMLAuthRequest{ - ConnectorID: req.connectorID, - CSRFToken: req.csrfToken, + ConnectorID: req.ConnectorID, + CSRFToken: req.CSRFToken, CreateWebSession: true, - ClientRedirectURL: req.clientRedirectURL, + ClientRedirectURL: req.ClientRedirectURL, }) if err != nil { logger.WithError(err).Error("Error creating auth request.") @@ -61,12 +61,12 @@ func (h *Handler) samlSSOConsole(w http.ResponseWriter, r *http.Request, p httpr req := new(client.SSOLoginConsoleReq) if err := httplib.ReadJSON(r, req); err != nil { logger.WithError(err).Error("Error reading json.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } if err := req.CheckAndSetDefaults(); err != nil { logger.WithError(err).Error("Missing request parameters.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } response, err := h.cfg.ProxyClient.CreateSAMLAuthRequest(r.Context(), types.SAMLAuthRequest{ @@ -81,7 +81,7 @@ func (h *Handler) samlSSOConsole(w http.ResponseWriter, r *http.Request, p httpr }) if err != nil { logger.WithError(err).Error("Failed to create SAML auth request.") - return nil, trace.AccessDenied(ssoLoginConsoleErr) + return nil, trace.AccessDenied(SSOLoginFailureMessage) } return &client.SSOLoginConsoleResponse{RedirectURL: response.RedirectURL}, nil @@ -108,7 +108,7 @@ func (h *Handler) samlACS(w http.ResponseWriter, r *http.Request, p httprouter.P // this improves the UX by terminating the failed SSO flow immediately, rather than hoping for a timeout. if requestID, errParse := auth.ParseSAMLInResponseTo(samlResponse); errParse == nil { if request, errGet := h.cfg.ProxyClient.GetSAMLAuthRequest(r.Context(), requestID); errGet == nil && !request.CreateWebSession { - if url, errEnc := redirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { + if url, errEnc := RedirectURLWithError(request.ClientRedirectURL, err); errEnc == nil { return url.String() } } @@ -130,19 +130,19 @@ func (h *Handler) samlACS(w http.ResponseWriter, r *http.Request, p httprouter.P redirect = "/web/" } - res := &ssoCallbackResponse{ - csrfToken: response.Req.CSRFToken, - username: response.Username, - sessionName: response.Session.GetName(), - clientRedirectURL: redirect, + res := &SSOCallbackResponse{ + CSRFToken: response.Req.CSRFToken, + Username: response.Username, + SessionName: response.Session.GetName(), + ClientRedirectURL: redirect, } - if err := ssoSetWebSessionAndRedirectURL(w, r, res, response.Req.CSRFToken != ""); err != nil { + if err := SSOSetWebSessionAndRedirectURL(w, r, res, response.Req.CSRFToken != ""); err != nil { logger.WithError(err).Error("Error setting web session.") return client.LoginFailedRedirectURL } - return res.clientRedirectURL + return res.ClientRedirectURL } logger.Debug("Callback redirecting to console login.")