diff --git a/lib/auth/apiserver.go b/lib/auth/apiserver.go index 9d7fdbf0410ae..7e4e9e2606943 100644 --- a/lib/auth/apiserver.go +++ b/lib/auth/apiserver.go @@ -719,32 +719,8 @@ func (s *APIServer) deleteCertAuthority(auth ClientI, w http.ResponseWriter, r * return message(fmt.Sprintf("cert '%v' deleted", id)), nil } -type validateOIDCAuthCallbackReq struct { - Query url.Values `json:"query"` -} - -// oidcAuthRawResponse is returned when auth server validated callback parameters -// returned from OIDC provider -type oidcAuthRawResponse struct { - // Username is authenticated teleport username - Username string `json:"username"` - // Identity contains validated OIDC identity - Identity types.ExternalIdentity `json:"identity"` - // Web session will be generated by auth server if requested in OIDCAuthRequest - Session json.RawMessage `json:"session,omitempty"` - // Cert will be generated by certificate authority - Cert []byte `json:"cert,omitempty"` - // TLSCert is PEM encoded TLS certificate - TLSCert []byte `json:"tls_cert,omitempty"` - // Req is original oidc auth request - Req OIDCAuthRequest `json:"req"` - // HostSigners is a list of signing host public keys - // trusted by proxy, used in console login - HostSigners []json.RawMessage `json:"host_signers"` -} - func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) { - var req *validateOIDCAuthCallbackReq + var req *ValidateOIDCAuthCallbackReq if err := httplib.ReadJSON(r, &req); err != nil { return nil, trace.Wrap(err) } @@ -752,7 +728,7 @@ func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter if err != nil { return nil, trace.Wrap(err) } - raw := oidcAuthRawResponse{ + raw := OIDCAuthRawResponse{ Username: response.Username, Identity: response.Identity, Cert: response.Cert, diff --git a/lib/auth/auth.go b/lib/auth/auth.go index 4abcfbfb33d8f..60f063fe5c86d 100644 --- a/lib/auth/auth.go +++ b/lib/auth/auth.go @@ -35,15 +35,12 @@ import ( "math/big" insecurerand "math/rand" "net" - "net/url" "os" "strings" "sync" "time" - "github.com/coreos/go-oidc/jose" "github.com/coreos/go-oidc/oauth2" - "github.com/coreos/go-oidc/oidc" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" @@ -244,7 +241,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { Authority: cfg.Authority, AuthServiceName: cfg.AuthServiceName, ServerID: cfg.HostUUID, - oidcClients: make(map[string]*oidcClient), githubClients: make(map[string]*githubClient), cancelFunc: cancelFunc, closeCtx: closeCtx, @@ -254,7 +250,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { Services: services, Cache: services, keyStore: keyStore, - getClaimsFun: getClaims, inventory: inventory.NewController(cfg.Presence), traceClient: cfg.TraceClient, fips: cfg.FIPS, @@ -291,6 +286,15 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) { } } + oas, err := NewOIDCAuthService(&OIDCAuthServiceConfig{ + Auth: &as, + Emitter: as.emitter, + }) + if err != nil { + return nil, trace.Wrap(err) + } + as.SetOIDCService(oas) + return &as, nil } @@ -394,9 +398,7 @@ var ( // - same for users and their sessions // - checks public keys to see if they're signed by it (can be trusted or not) type Server struct { - lock sync.RWMutex - // oidcClients is a map from authID & proxyAddr -> oidcClient - oidcClients map[string]*oidcClient + lock sync.RWMutex githubClients map[string]*githubClient clock clockwork.Clock bk backend.Backend @@ -405,6 +407,7 @@ type Server struct { cancelFunc context.CancelFunc samlAuthService SAMLService + oidcAuthService OIDCService sshca.Authority @@ -455,9 +458,6 @@ type Server struct { // lockWatcher is a lock watcher, used to verify cert generation requests. lockWatcher *services.LockWatcher - // getClaimsFun is used in tests for overriding the implementation of getClaims method used in OIDC. - getClaimsFun func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) - inventory *inventory.Controller // githubOrgSSOCache is used to cache whether Github organizations use @@ -483,13 +483,20 @@ type Server struct { loadAllCAs bool } -// SetSAMLService registers ss as the SAMLService that provides the SAML +// SetSAMLService registers svc as the SAMLService that provides the SAML // connector implementation. If a SAMLService has already been registered, this // will override the previous registration. func (a *Server) SetSAMLService(svc SAMLService) { a.samlAuthService = svc } +// SetOIDCService registers svc as the OIDCService that provides the OIDC +// connector implementation. If a OIDCService has already been registered, this +// will override the previous registration. +func (a *Server) SetOIDCService(svc OIDCService) { + a.oidcAuthService = svc +} + func (a *Server) CloseContext() context.Context { return a.closeCtx } @@ -3990,17 +3997,6 @@ const ( SessionTokenBytes = 32 ) -// oidcClient is internal structure that stores OIDC client and its config -type oidcClient struct { - client *oidc.Client - connector types.OIDCConnector - // syncCtx controls the provider sync goroutine. - syncCtx context.Context - syncCancel context.CancelFunc - // firstSync will be closed once the first provider sync succeeds - firstSync chan struct{} -} - // githubClient is internal structure that stores Github OAuth 2client and its config type githubClient struct { client *oauth2.Client @@ -4038,19 +4034,6 @@ func oauth2ConfigsEqual(a, b oauth2.Config) bool { return true } -// isHTTPS checks if the scheme for a URL is https or not. -func isHTTPS(u string) error { - earl, err := url.Parse(u) - if err != nil { - return trace.Wrap(err) - } - if earl.Scheme != "https" { - return trace.BadParameter("expected scheme https, got %q", earl.Scheme) - } - - return nil -} - // WithClusterCAs returns a TLS hello callback that returns a copy of the provided // TLS config with client CAs pool of the specified cluster. func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) { diff --git a/lib/auth/auth_test.go b/lib/auth/auth_test.go index 31c2df6ecb895..a51b3f89f03aa 100644 --- a/lib/auth/auth_test.go +++ b/lib/auth/auth_test.go @@ -21,7 +21,6 @@ import ( "crypto/rand" "crypto/rsa" "crypto/x509/pkix" - "encoding/json" "errors" "fmt" mathrand "math/rand" @@ -30,7 +29,6 @@ import ( "testing" "time" - "github.com/coreos/go-oidc/jose" "github.com/google/go-cmp/cmp" "github.com/google/uuid" reporting "github.com/gravitational/reporting/types" @@ -754,107 +752,6 @@ func TestGenerateTokenEventsEmitted(t *testing.T) { require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.TrustedClusterTokenCreateEvent) } -func TestValidateACRValues(t *testing.T) { - s := newAuthSuite(t) - - tests := []struct { - comment string - inIDToken string - inACRValue string - inACRProvider string - outIsValid require.ErrorAssertionFunc - }{ - { - "0 - default, acr values match", - ` -{ - "acr": "foo", - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo", - "", - require.NoError, - }, - { - "1 - default, acr values do not match", - ` -{ - "acr": "foo", - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "bar", - "", - require.Error, - }, - { - "2 - netiq, acr values match", - ` -{ - "acr": { - "values": [ - "foo/bar/baz" - ] - }, - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo/bar/baz", - "netiq", - require.NoError, - }, - { - "3 - netiq, invalid format", - ` -{ - "acr": { - "values": "foo/bar/baz" - }, - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo/bar/baz", - "netiq", - require.Error, - }, - { - "4 - netiq, invalid value", - ` -{ - "acr": { - "values": [ - "foo/bar/baz/qux" - ] - }, - "aud": "00000000-0000-0000-0000-000000000000", - "exp": 1111111111 -} - `, - "foo/bar/baz", - "netiq", - require.Error, - }, - } - - for _, tt := range tests { - tt := tt - t.Run(tt.comment, func(t *testing.T) { - t.Parallel() - var claims jose.Claims - err := json.Unmarshal([]byte(tt.inIDToken), &claims) - require.NoError(t, err) - - err = s.a.validateACRValues(tt.inACRValue, tt.inACRProvider, claims) - tt.outIsValid(t, err) - }) - } -} - func TestUpdateConfig(t *testing.T) { t.Parallel() s := newAuthSuite(t) diff --git a/lib/auth/clt.go b/lib/auth/clt.go index da9b271556499..0fff53ec56951 100644 --- a/lib/auth/clt.go +++ b/lib/auth/clt.go @@ -873,13 +873,13 @@ func (c *Client) GenerateHostCert( // ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect func (c *Client) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { - out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), validateOIDCAuthCallbackReq{ + out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{ Query: q, }) if err != nil { return nil, trace.Wrap(err) } - var rawResponse *oidcAuthRawResponse + var rawResponse *OIDCAuthRawResponse if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil { return nil, trace.Wrap(err) } diff --git a/lib/auth/oidc.go b/lib/auth/oidc.go index b1396d3c15fdc..2232ceaf7457d 100644 --- a/lib/auth/oidc.go +++ b/lib/auth/oidc.go @@ -23,6 +23,7 @@ import ( "io" "net/http" "net/url" + "sync" "time" "github.com/coreos/go-oidc/jose" @@ -44,12 +45,148 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +type OIDCService interface { + CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) + ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) +} + +var errOIDCNotImplemented = trace.AccessDenied("OIDC is only available in enterprise subscriptions") + +func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) { + if a.oidcAuthService == nil { + return nil, errOIDCNotImplemented + } + + rq, err := a.oidcAuthService.CreateOIDCAuthRequest(ctx, req) + return rq, trace.Wrap(err) +} + +func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { + if a.oidcAuthService == nil { + return nil, errOIDCNotImplemented + } + + resp, err := a.oidcAuthService.ValidateOIDCAuthCallback(ctx, q) + return resp, trace.Wrap(err) +} + +// OIDCAuthResponse is returned when auth server validated callback parameters +// returned from OIDC provider +type OIDCAuthResponse struct { + // Username is authenticated teleport username + Username string `json:"username"` + // Identity contains validated OIDC identity + Identity types.ExternalIdentity `json:"identity"` + // Web session will be generated by auth server if requested in OIDCAuthRequest + Session types.WebSession `json:"session,omitempty"` + // Cert will be generated by certificate authority + Cert []byte `json:"cert,omitempty"` + // TLSCert is PEM encoded TLS certificate + TLSCert []byte `json:"tls_cert,omitempty"` + // Req is original oidc auth request + Req OIDCAuthRequest `json:"req"` + // HostSigners is a list of signing host public keys + // trusted by proxy, used in console login + HostSigners []types.CertAuthority `json:"host_signers"` +} + +// OIDCAuthRequest is an OIDC auth request that supports standard json marshaling. +type OIDCAuthRequest struct { + // ConnectorID is ID of OIDC connector this request uses + ConnectorID string `json:"connector_id"` + // CSRFToken is associated with user web session token + CSRFToken string `json:"csrf_token"` + // PublicKey is an optional public key, users want these + // keys to be signed by auth servers user CA in case + // of successful auth + PublicKey []byte `json:"public_key"` + // CreateWebSession indicates if user wants to generate a web + // session after successful authentication + CreateWebSession bool `json:"create_web_session"` + // ClientRedirectURL is a URL client wants to be redirected + // after successful authentication + ClientRedirectURL string `json:"client_redirect_url"` +} + +// ValidateOIDCAuthCallbackReq is the request made by the proxy to validate +// and activate a login via OIDC. +type ValidateOIDCAuthCallbackReq struct { + Query url.Values `json:"query"` +} + +// OIDCAuthRawResponse is returned when auth server validated callback parameters +// returned from OIDC provider +type OIDCAuthRawResponse struct { + // Username is authenticated teleport username + Username string `json:"username"` + // Identity contains validated OIDC identity + Identity types.ExternalIdentity `json:"identity"` + // Web session will be generated by auth server if requested in OIDCAuthRequest + Session json.RawMessage `json:"session,omitempty"` + // Cert will be generated by certificate authority + Cert []byte `json:"cert,omitempty"` + // TLSCert is PEM encoded TLS certificate + TLSCert []byte `json:"tls_cert,omitempty"` + // Req is original oidc auth request + Req OIDCAuthRequest `json:"req"` + // HostSigners is a list of signing host public keys + // trusted by proxy, used in console login + HostSigners []json.RawMessage `json:"host_signers"` +} + +type OIDCAuthService struct { + auth *Server + emitter apievents.Emitter + clients map[string]*oidcClient + lock sync.Mutex + getClaimsFun func(ctx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) +} + +type OIDCAuthServiceConfig struct { + Auth *Server + Emitter apievents.Emitter +} + +func (cfg *OIDCAuthServiceConfig) CheckAndSetDefaults() error { + if cfg.Auth == nil { + return trace.BadParameter("auth.Server not provided") + } + if cfg.Emitter == nil { + cfg.Emitter = events.NewDiscardEmitter() + } + return nil +} + +func NewOIDCAuthService(cfg *OIDCAuthServiceConfig) (*OIDCAuthService, error) { + if err := cfg.CheckAndSetDefaults(); err != nil { + return nil, err + } + + return &OIDCAuthService{ + auth: cfg.Auth, + emitter: cfg.Emitter, + clients: make(map[string]*oidcClient), + getClaimsFun: getClaims, + }, nil +} + +// oidcClient is internal structure that stores OIDC client and its config +type oidcClient struct { + client *oidc.Client + connector types.OIDCConnector + // syncCtx controls the provider sync goroutine. + syncCtx context.Context + syncCancel context.CancelFunc + // firstSync will be closed once the first provider sync succeeds + firstSync chan struct{} +} + // ErrOIDCNoRoles results from not mapping any roles from OIDC claims. var ErrOIDCNoRoles = trace.AccessDenied("No roles mapped from claims. The mappings may contain typos.") // getOIDCConnectorAndClient returns the associated oidc connector // and client for the given oidc auth request. -func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OIDCAuthRequest) (types.OIDCConnector, *oidc.Client, error) { +func (oas *OIDCAuthService) getOIDCConnectorAndClient(ctx context.Context, request types.OIDCAuthRequest) (types.OIDCConnector, *oidc.Client, error) { // stateless test flow if request.SSOTestFlow { if request.ConnectorSpec == nil { @@ -76,7 +213,7 @@ func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OI // close this request-scoped oidc client after 10 minutes go func() { - ticker := a.GetClock().NewTicker(defaults.OIDCAuthRequestTTL) + ticker := oas.auth.GetClock().NewTicker(defaults.OIDCAuthRequestTTL) defer ticker.Stop() select { case <-ticker.Chan(): @@ -89,12 +226,12 @@ func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OI } // regular execution flow - connector, err := a.GetOIDCConnector(ctx, request.ConnectorID, true) + connector, err := oas.auth.GetOIDCConnector(ctx, request.ConnectorID, true) if err != nil { return nil, nil, trace.Wrap(err) } - client, err := a.getCachedOIDCClient(ctx, connector, request.ProxyAddress) + client, err := oas.getCachedOIDCClient(ctx, connector, request.ProxyAddress) if err != nil { return nil, nil, trace.Wrap(err) } @@ -110,22 +247,22 @@ func (a *Server) getOIDCConnectorAndClient(ctx context.Context, request types.OI // getCachedOIDCClient gets a cached oidc client for // the given OIDC connector and redirectURL preference. -func (a *Server) getCachedOIDCClient(ctx context.Context, conn types.OIDCConnector, proxyAddr string) (*oidcClient, error) { - a.lock.Lock() - defer a.lock.Unlock() +func (oas *OIDCAuthService) getCachedOIDCClient(ctx context.Context, conn types.OIDCConnector, proxyAddr string) (*oidcClient, error) { + oas.lock.Lock() + defer oas.lock.Unlock() // Each connector and proxy combination has a distinct client, // so we use a composite key to capture all combinations. clientMapKey := conn.GetName() + "_" + proxyAddr - cachedClient, ok := a.oidcClients[clientMapKey] + cachedClient, ok := oas.clients[clientMapKey] if ok { if !cachedClient.needsRefresh(conn) && cachedClient.syncCtx.Err() == nil { return cachedClient, nil } // Cached client needs to be refreshed or is no longer syncing. cachedClient.syncCancel() - delete(a.oidcClients, clientMapKey) + delete(oas.clients, clientMapKey) } // Create a new oidc client and add it to the cache. @@ -134,7 +271,7 @@ func (a *Server) getCachedOIDCClient(ctx context.Context, conn types.OIDCConnect return nil, trace.Wrap(err) } - a.oidcClients[clientMapKey] = client + oas.clients[clientMapKey] = client return client, nil } @@ -259,12 +396,12 @@ func (a *Server) DeleteOIDCConnector(ctx context.Context, connectorName string) return nil } -func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) { +func (oas *OIDCAuthService) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRequest) (*types.OIDCAuthRequest, error) { // ensure prompt removal of OIDC client in test flows. does nothing in regular flows. ctx, cancel := context.WithCancel(ctx) defer cancel() - connector, client, err := a.getOIDCConnectorAndClient(ctx, req) + connector, client, err := oas.getOIDCConnectorAndClient(ctx, req) if err != nil { return nil, trace.Wrap(err) } @@ -299,7 +436,7 @@ func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRe log.Debugf("OIDC redirect URL: %v.", req.RedirectURL) - err = a.Services.CreateOIDCAuthRequest(ctx, req, defaults.OIDCAuthRequestTTL) + err = oas.auth.Services.CreateOIDCAuthRequest(ctx, req, defaults.OIDCAuthRequestTTL) if err != nil { return nil, trace.Wrap(err) } @@ -309,7 +446,7 @@ func (a *Server) CreateOIDCAuthRequest(ctx context.Context, req types.OIDCAuthRe // ValidateOIDCAuthCallback is called by the proxy to check OIDC query parameters // returned by OIDC Provider, if everything checks out, auth server // will respond with OIDCAuthResponse, otherwise it will return error -func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { +func (oas *OIDCAuthService) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) { event := &apievents.UserLogin{ Metadata: apievents.Metadata{ Type: events.UserLoginEvent, @@ -317,9 +454,9 @@ func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O Method: events.LoginMethodOIDC, } - diagCtx := NewSSODiagContext(types.KindOIDC, a) + diagCtx := NewSSODiagContext(types.KindOIDC, oas.auth) - auth, err := a.validateOIDCAuthCallback(ctx, diagCtx, q) + auth, err := oas.validateOIDCAuthCallback(ctx, diagCtx, q) diagCtx.Info.Error = trace.UserMessage(err) diagCtx.WriteToBackend(ctx) @@ -344,7 +481,7 @@ func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O event.Status.Error = trace.Unwrap(err).Error() event.Status.UserMessage = err.Error() - if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil { + if err := oas.emitter.EmitAuditEvent(ctx, event); err != nil { log.WithError(err).Warn("Failed to emit OIDC login failed event.") } @@ -358,7 +495,7 @@ func (a *Server) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*O event.User = auth.Username event.Status.Success = true - if err := a.emitter.EmitAuditEvent(a.closeCtx, event); err != nil { + if err := oas.emitter.EmitAuditEvent(ctx, event); err != nil { log.WithError(err).Warn("Failed to emit OIDC login event.") } @@ -398,13 +535,13 @@ func checkEmailVerifiedClaim(claims jose.Claims) error { return nil } -func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*OIDCAuthResponse, error) { +func (oas *OIDCAuthService) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagContext, q url.Values) (*OIDCAuthResponse, error) { if errParam := q.Get("error"); errParam != "" { // try to find request so the error gets logged against it. state := q.Get("state") if state != "" { diagCtx.RequestID = state - req, err := a.GetOIDCAuthRequest(ctx, state) + req, err := oas.auth.GetOIDCAuthRequest(ctx, state) if err == nil { diagCtx.Info.TestFlow = req.SSOTestFlow } @@ -428,7 +565,7 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagC } diagCtx.RequestID = stateToken - req, err := a.GetOIDCAuthRequest(ctx, stateToken) + req, err := oas.auth.GetOIDCAuthRequest(ctx, stateToken) if err != nil { return nil, trace.Wrap(err, "Failed to get OIDC Auth Request.") } @@ -438,13 +575,13 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagC ctxC, cancel := context.WithCancel(ctx) defer cancel() - connector, client, err := a.getOIDCConnectorAndClient(ctxC, *req) + connector, client, err := oas.getOIDCConnectorAndClient(ctxC, *req) if err != nil { return nil, trace.Wrap(err, "Failed to get OIDC connector and client.") } // extract claims from both the id token and the userinfo endpoint and merge them - claims, err := a.getClaims(client, connector, code) + claims, err := oas.getClaims(ctx, client, connector, code) if err != nil { // different error message for Google Workspace as likely cause is different. if isGoogleWorkspaceConnector(connector) { @@ -465,7 +602,7 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagC // if we are sending acr values, make sure we also validate them acrValue := connector.GetACR() if acrValue != "" { - err := a.validateACRValues(acrValue, connector.GetProvider(), claims) + err := validateACRValues(acrValue, connector.GetProvider(), claims) if err != nil { return nil, trace.Wrap(err, "OIDC ACR validation failure.") } @@ -494,7 +631,7 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagC // Calculate (figure out name, roles, traits, session TTL) of user and // create the user in the backend. - params, err := a.calculateOIDCUser(diagCtx, connector, claims, ident, req) + params, err := oas.calculateOIDCUser(diagCtx, connector, claims, ident, req) if err != nil { return nil, trace.Wrap(err, "Failed to calculate user attributes.") } @@ -509,13 +646,13 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagC SessionTTL: types.Duration(params.SessionTTL), } - user, err := a.createOIDCUser(params, req.SSOTestFlow) + user, err := oas.createOIDCUser(params, req.SSOTestFlow) if err != nil { return nil, trace.Wrap(err, "Failed to create user from provided parameters.") } // Auth was successful, return session, certificate, etc. to caller. - auth := &OIDCAuthResponse{ + resp := &OIDCAuthResponse{ Req: OIDCAuthRequestFromProto(req), Identity: types.ExternalIdentity{ ConnectorID: params.ConnectorName, @@ -527,93 +664,55 @@ func (a *Server) validateOIDCAuthCallback(ctx context.Context, diagCtx *SSODiagC // In test flow skip signing and creating web sessions. if req.SSOTestFlow { diagCtx.Info.Success = true - return auth, nil + return resp, nil } if !req.CheckUser { - return auth, nil + return resp, nil } // If the request is coming from a browser, create a web session. if req.CreateWebSession { - session, err := a.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ + session, err := oas.auth.CreateWebSessionFromReq(ctx, types.NewWebSessionRequest{ User: user.GetName(), Roles: user.GetRoles(), Traits: user.GetTraits(), SessionTTL: params.SessionTTL, - LoginTime: a.clock.Now().UTC(), + LoginTime: oas.auth.GetClock().Now().UTC(), }) if err != nil { return nil, trace.Wrap(err, "Failed to create web session.") } - auth.Session = session + resp.Session = session } // If a public key was provided, sign it and return a certificate. if len(req.PublicKey) != 0 { - sshCert, tlsCert, err := a.CreateSessionCert(user, params.SessionTTL, req.PublicKey, req.Compatibility, req.RouteToCluster, + sshCert, tlsCert, err := oas.auth.CreateSessionCert(user, params.SessionTTL, req.PublicKey, req.Compatibility, req.RouteToCluster, req.KubernetesCluster, keys.AttestationStatementFromProto(req.AttestationStatement)) if err != nil { return nil, trace.Wrap(err, "Failed to create session certificate.") } - clusterName, err := a.GetClusterName() + clusterName, err := oas.auth.GetClusterName() if err != nil { return nil, trace.Wrap(err, "Failed to obtain cluster name.") } - auth.Cert = sshCert - auth.TLSCert = tlsCert + resp.Cert = sshCert + resp.TLSCert = tlsCert // Return the host CA for this cluster only. - authority, err := a.GetCertAuthority(ctx, types.CertAuthID{ + authority, err := oas.auth.GetCertAuthority(ctx, types.CertAuthID{ Type: types.HostCA, DomainName: clusterName.GetClusterName(), }, false) if err != nil { return nil, trace.Wrap(err, "Failed to obtain cluster's host CA.") } - auth.HostSigners = append(auth.HostSigners, authority) + resp.HostSigners = append(resp.HostSigners, authority) } - return auth, nil -} - -// OIDCAuthResponse is returned when auth server validated callback parameters -// returned from OIDC provider -type OIDCAuthResponse struct { - // Username is authenticated teleport username - Username string `json:"username"` - // Identity contains validated OIDC identity - Identity types.ExternalIdentity `json:"identity"` - // Web session will be generated by auth server if requested in OIDCAuthRequest - Session types.WebSession `json:"session,omitempty"` - // Cert will be generated by certificate authority - Cert []byte `json:"cert,omitempty"` - // TLSCert is PEM encoded TLS certificate - TLSCert []byte `json:"tls_cert,omitempty"` - // Req is original oidc auth request - Req OIDCAuthRequest `json:"req"` - // HostSigners is a list of signing host public keys - // trusted by proxy, used in console login - HostSigners []types.CertAuthority `json:"host_signers"` -} - -// OIDCAuthRequest is an OIDC auth request that supports standard json marshaling. -type OIDCAuthRequest struct { - // ConnectorID is ID of OIDC connector this request uses - ConnectorID string `json:"connector_id"` - // CSRFToken is associated with user web session token - CSRFToken string `json:"csrf_token"` - // PublicKey is an optional public key, users want these - // keys to be signed by auth servers user CA in case - // of successful auth - PublicKey []byte `json:"public_key"` - // CreateWebSession indicates if user wants to generate a web - // session after successful authentication - CreateWebSession bool `json:"create_web_session"` - // ClientRedirectURL is a URL client wants to be redirected - // after successful authentication - ClientRedirectURL string `json:"client_redirect_url"` + return resp, nil } // OIDCAuthRequestFromProto converts the types.OIDCAuthRequest to OIDCAuthRequest. @@ -627,7 +726,7 @@ func OIDCAuthRequestFromProto(req *types.OIDCAuthRequest) OIDCAuthRequest { } } -func (a *Server) calculateOIDCUser(diagCtx *SSODiagContext, connector types.OIDCConnector, claims jose.Claims, ident *oidc.Identity, request *types.OIDCAuthRequest) (*CreateUserParams, error) { +func (oas *OIDCAuthService) calculateOIDCUser(diagCtx *SSODiagContext, connector types.OIDCConnector, claims jose.Claims, ident *oidc.Identity, request *types.OIDCAuthRequest) (*CreateUserParams, error) { var err error username, err := usernameFromClaims(connector, claims, ident) @@ -664,7 +763,7 @@ func (a *Server) calculateOIDCUser(diagCtx *SSODiagContext, connector types.OIDC } // Pick smaller for role: session TTL from role or requested TTL. - roles, err := services.FetchRoles(p.Roles, a, p.Traits) + roles, err := services.FetchRoles(p.Roles, oas.auth, p.Traits) if err != nil { return nil, trace.Wrap(err) } @@ -674,8 +773,8 @@ func (a *Server) calculateOIDCUser(diagCtx *SSODiagContext, connector types.OIDC return &p, nil } -func (a *Server) createOIDCUser(p *CreateUserParams, dryRun bool) (types.User, error) { - expires := a.GetClock().Now().UTC().Add(p.SessionTTL) +func (oas *OIDCAuthService) createOIDCUser(p *CreateUserParams, dryRun bool) (types.User, error) { + expires := oas.auth.GetClock().Now().UTC().Add(p.SessionTTL) log.Debugf("Generating dynamic OIDC identity %v/%v with roles: %v. Dry run: %v.", p.ConnectorName, p.Username, p.Roles, dryRun) user := &types.UserV2{ @@ -697,7 +796,7 @@ func (a *Server) createOIDCUser(p *CreateUserParams, dryRun bool) (types.User, e }, CreatedBy: types.CreatedBy{ User: types.UserRef{Name: teleport.UserSystem}, - Time: a.clock.Now().UTC(), + Time: oas.auth.GetClock().Now().UTC(), Connector: &types.ConnectorRef{ Type: constants.OIDC, ID: p.ConnectorName, @@ -712,7 +811,7 @@ func (a *Server) createOIDCUser(p *CreateUserParams, dryRun bool) (types.User, e } // Get the user to check if it already exists or not. - existingUser, err := a.Services.GetUser(p.Username, false) + existingUser, err := oas.auth.Services.GetUser(p.Username, false) if err != nil && !trace.IsNotFound(err) { return nil, trace.Wrap(err) } @@ -732,11 +831,11 @@ func (a *Server) createOIDCUser(p *CreateUserParams, dryRun bool) (types.User, e log.Debugf("Overwriting existing user %q created with %v connector %v.", existingUser.GetName(), connectorRef.Type, connectorRef.ID) - if err := a.UpdateUser(ctx, user); err != nil { + if err := oas.auth.UpdateUser(ctx, user); err != nil { return nil, trace.Wrap(err) } } else { - if err := a.CreateUser(ctx, user); err != nil { + if err := oas.auth.CreateUser(ctx, user); err != nil { return nil, trace.Wrap(err) } } @@ -874,12 +973,12 @@ func mergeClaims(a jose.Claims, b jose.Claims) (jose.Claims, error) { } // getClaims gets claims from ID token and UserInfo and returns UserInfo claims merged into ID token claims. -func (a *Server) getClaims(oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { - return a.getClaimsFun(a.closeCtx, oidcClient, connector, code) +func (oas *OIDCAuthService) getClaims(ctx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { + return oas.getClaimsFun(ctx, oidcClient, connector, code) } -// getClaims implements Server.getClaims, but allows that code path to be overridden for testing. -func getClaims(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { +// getClaims implements OIDCAuthService.getClaims, but allows that code path to be overridden for testing. +func getClaims(ctx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { oac, err := getOAuthClient(oidcClient, connector) if err != nil { return nil, trace.Wrap(err) @@ -944,7 +1043,7 @@ func getClaims(closeCtx context.Context, oidcClient *oidc.Client, connector type } if isGoogleWorkspaceConnector(connector) { - claims, err = addGoogleWorkspaceClaims(closeCtx, connector, claims) + claims, err = addGoogleWorkspaceClaims(ctx, connector, claims) if err != nil { return nil, trace.Wrap(err) } @@ -974,7 +1073,7 @@ func getOAuthClient(oidcClient *oidc.Client, connector types.OIDCConnector) (*oa // validateACRValues validates that we get an appropriate response for acr values. By default // we expect the same value we send, but this function also handles Identity Provider specific // forms of validation. -func (a *Server) validateACRValues(acrValue string, identityProvider string, claims jose.Claims) error { +func validateACRValues(acrValue string, identityProvider string, claims jose.Claims) error { switch identityProvider { case teleport.NetIQ: log.Debugf("Validating OIDC ACR values with '%v' rules.", identityProvider) @@ -1029,3 +1128,16 @@ func (a *Server) validateACRValues(acrValue string, identityProvider string, cla return nil } + +// isHTTPS checks if the scheme for a URL is https or not. +func isHTTPS(u string) error { + earl, err := url.Parse(u) + if err != nil { + return trace.Wrap(err) + } + if earl.Scheme != "https" { + return trace.BadParameter("expected scheme https, got %q", earl.Scheme) + } + + return nil +} diff --git a/lib/auth/oidc_test.go b/lib/auth/oidc_test.go index c7c07e42ba740..78a96be4e4e70 100644 --- a/lib/auth/oidc_test.go +++ b/lib/auth/oidc_test.go @@ -51,9 +51,10 @@ import ( ) type OIDCSuite struct { - a *Server - b backend.Backend - c clockwork.FakeClock + a *Server + b backend.Backend + c clockwork.FakeClock + oas *OIDCAuthService } func setUpSuite(t *testing.T) *OIDCSuite { @@ -87,6 +88,11 @@ func setUpSuite(t *testing.T) *OIDCSuite { } s.a, err = NewServer(authConfig) require.NoError(t, err) + + var ok bool + s.oas, ok = s.a.oidcAuthService.(*OIDCAuthService) + require.True(t, ok, "Server.oidcAuthService is not type *OIDCAuthService") + return &s } @@ -112,7 +118,7 @@ func TestCreateOIDCUser(t *testing.T) { s := setUpSuite(t) // Dry-run creation of OIDC user. - user, err := s.a.createOIDCUser(&CreateUserParams{ + user, err := s.oas.createOIDCUser(&CreateUserParams{ ConnectorName: "oidcService", Username: "foo@example.com", Roles: []string{"admin"}, @@ -126,7 +132,7 @@ func TestCreateOIDCUser(t *testing.T) { require.Error(t, err) // Create OIDC user with 1 minute expiry. - _, err = s.a.createOIDCUser(&CreateUserParams{ + _, err = s.oas.createOIDCUser(&CreateUserParams{ ConnectorName: "oidcService", Username: "foo@example.com", Roles: []string{"admin"}, @@ -153,6 +159,7 @@ func TestUserInfoBlockHTTP(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) @@ -166,7 +173,7 @@ func TestUserInfoBlockHTTP(t *testing.T) { }) require.NoError(t, err) - oidcClient, err := s.a.getCachedOIDCClient(ctx, connector, "") + oidcClient, err := s.oas.getCachedOIDCClient(ctx, connector, "") require.NoError(t, err) // Verify HTTP endpoints return trace.NotFound. @@ -232,6 +239,7 @@ func TestSSODiagnostic(t *testing.T) { t.Run(tc.name, func(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) @@ -274,7 +282,7 @@ func TestSSODiagnostic(t *testing.T) { } // override getClaimsFun. - s.a.getClaimsFun = func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { + s.oas.getClaimsFun = func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error) { cc := map[string]interface{}{ "email_verified": true, "groups": []string{"everyone", "idp-admin", "idp-dev"}, @@ -285,7 +293,7 @@ func TestSSODiagnostic(t *testing.T) { return cc, nil } - resp, err := s.a.ValidateOIDCAuthCallback(ctx, values) + resp, err := s.oas.ValidateOIDCAuthCallback(ctx, values) if tc.wantValidateErr != nil { require.ErrorIs(t, err, tc.wantValidateErr) return @@ -304,7 +312,7 @@ func TestSSODiagnostic(t *testing.T) { diagCtx := SSODiagContext{} - resp, err = s.a.validateOIDCAuthCallback(ctx, &diagCtx, values) + resp, err = s.oas.validateOIDCAuthCallback(ctx, &diagCtx, values) require.NoError(t, err) require.NotNil(t, resp) require.Equal(t, &OIDCAuthResponse{ @@ -377,6 +385,7 @@ func TestPingProvider(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) @@ -410,7 +419,7 @@ func TestPingProvider(t *testing.T) { }, } { t.Run(fmt.Sprintf("Test SSOFlow: %v", req.SSOTestFlow), func(t *testing.T) { - oidcConnector, oidcClient, err := s.a.getOIDCConnectorAndClient(ctx, req) + oidcConnector, oidcClient, err := s.oas.getOIDCConnectorAndClient(ctx, req) require.NoError(t, err) oac, err := getOAuthClient(oidcClient, oidcConnector) @@ -487,6 +496,7 @@ func TestOIDCClientCache(t *testing.T) { ctx := context.Background() s := setUpSuite(t) + // Create configurable IdP to use in tests. idp := newFakeIDP(t, false /* tls */) connectorSpec := types.OIDCConnectorSpecV3{ @@ -501,17 +511,17 @@ func TestOIDCClientCache(t *testing.T) { require.NoError(t, err) // Create and cache a new oidc client - client, err := s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + client, err := s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) // The next call should return the same client (compare memory address) - cachedClient, err := s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + cachedClient, err := s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) require.True(t, client == cachedClient) // Canceling provider sync on a cached client should cause it to be replaced client.syncCancel() - cachedClient, err = s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + cachedClient, err = s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) require.False(t, client == cachedClient) @@ -560,12 +570,12 @@ func TestOIDCClientCache(t *testing.T) { require.NoError(t, err) tc.mutateConnector(newConnector) - client, err = s.a.getCachedOIDCClient(ctx, newConnector, "proxy.example.com") + client, err = s.oas.getCachedOIDCClient(ctx, newConnector, "proxy.example.com") require.NoError(t, err) require.True(t, (client == originalClient) == tc.expectNoRefresh) // reset cached client to the original client for remaining tests - originalClient, err = s.a.getCachedOIDCClient(ctx, connector, "proxy.example.com") + originalClient, err = s.oas.getCachedOIDCClient(ctx, connector, "proxy.example.com") require.NoError(t, err) }) } @@ -880,7 +890,7 @@ func TestUsernameClaim(t *testing.T) { require.NoError(t, err) // Generate the userCreateParams for the OIDC user. - createUserParams, err := s.a.calculateOIDCUser(&diagCtx, connector, claims, ident, request) + createUserParams, err := s.oas.calculateOIDCUser(&diagCtx, connector, claims, ident, request) if tc.expectedError != "" { require.ErrorContains(t, err, tc.expectedError) } else { @@ -890,3 +900,102 @@ func TestUsernameClaim(t *testing.T) { }) } } + +func TestValidateACRValues(t *testing.T) { + tests := []struct { + comment string + inIDToken string + inACRValue string + inACRProvider string + outIsValid require.ErrorAssertionFunc + }{ + { + "0 - default, acr values match", + ` +{ + "acr": "foo", + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo", + "", + require.NoError, + }, + { + "1 - default, acr values do not match", + ` +{ + "acr": "foo", + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "bar", + "", + require.Error, + }, + { + "2 - netiq, acr values match", + ` +{ + "acr": { + "values": [ + "foo/bar/baz" + ] + }, + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo/bar/baz", + "netiq", + require.NoError, + }, + { + "3 - netiq, invalid format", + ` +{ + "acr": { + "values": "foo/bar/baz" + }, + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo/bar/baz", + "netiq", + require.Error, + }, + { + "4 - netiq, invalid value", + ` +{ + "acr": { + "values": [ + "foo/bar/baz/qux" + ] + }, + "aud": "00000000-0000-0000-0000-000000000000", + "exp": 1111111111 +} + `, + "foo/bar/baz", + "netiq", + require.Error, + }, + } + + for _, tt := range tests { + tt := tt + t.Run(tt.comment, func(t *testing.T) { + t.Parallel() + var claims jose.Claims + err := json.Unmarshal([]byte(tt.inIDToken), &claims) + require.NoError(t, err) + + err = validateACRValues(tt.inACRValue, tt.inACRProvider, claims) + tt.outIsValid(t, err) + }) + } +}