Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 2 additions & 26 deletions lib/auth/apiserver.go
Original file line number Diff line number Diff line change
Expand Up @@ -719,40 +719,16 @@ func (s *APIServer) deleteCertAuthority(auth ClientI, w http.ResponseWriter, r *
return message(fmt.Sprintf("cert '%v' deleted", id)), nil
}

type validateOIDCAuthCallbackReq struct {
Query url.Values `json:"query"`
}

// oidcAuthRawResponse is returned when auth server validated callback parameters
// returned from OIDC provider
type oidcAuthRawResponse struct {
// Username is authenticated teleport username
Username string `json:"username"`
// Identity contains validated OIDC identity
Identity types.ExternalIdentity `json:"identity"`
// Web session will be generated by auth server if requested in OIDCAuthRequest
Session json.RawMessage `json:"session,omitempty"`
// Cert will be generated by certificate authority
Cert []byte `json:"cert,omitempty"`
// TLSCert is PEM encoded TLS certificate
TLSCert []byte `json:"tls_cert,omitempty"`
// Req is original oidc auth request
Req OIDCAuthRequest `json:"req"`
// HostSigners is a list of signing host public keys
// trusted by proxy, used in console login
HostSigners []json.RawMessage `json:"host_signers"`
}

func (s *APIServer) validateOIDCAuthCallback(auth ClientI, w http.ResponseWriter, r *http.Request, p httprouter.Params, version string) (interface{}, error) {
var req *validateOIDCAuthCallbackReq
var req *ValidateOIDCAuthCallbackReq
if err := httplib.ReadJSON(r, &req); err != nil {
return nil, trace.Wrap(err)
}
response, err := auth.ValidateOIDCAuthCallback(r.Context(), req.Query)
if err != nil {
return nil, trace.Wrap(err)
}
raw := oidcAuthRawResponse{
raw := OIDCAuthRawResponse{
Username: response.Username,
Identity: response.Identity,
Cert: response.Cert,
Expand Down
55 changes: 19 additions & 36 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,12 @@ import (
"math/big"
insecurerand "math/rand"
"net"
"net/url"
"os"
"strings"
"sync"
"time"

"github.com/coreos/go-oidc/jose"
"github.com/coreos/go-oidc/oauth2"
"github.com/coreos/go-oidc/oidc"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
Expand Down Expand Up @@ -244,7 +241,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
Authority: cfg.Authority,
AuthServiceName: cfg.AuthServiceName,
ServerID: cfg.HostUUID,
oidcClients: make(map[string]*oidcClient),
githubClients: make(map[string]*githubClient),
cancelFunc: cancelFunc,
closeCtx: closeCtx,
Expand All @@ -254,7 +250,6 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
Services: services,
Cache: services,
keyStore: keyStore,
getClaimsFun: getClaims,
inventory: inventory.NewController(cfg.Presence),
traceClient: cfg.TraceClient,
fips: cfg.FIPS,
Expand Down Expand Up @@ -291,6 +286,15 @@ func NewServer(cfg *InitConfig, opts ...ServerOption) (*Server, error) {
}
}

oas, err := NewOIDCAuthService(&OIDCAuthServiceConfig{
Auth: &as,
Emitter: as.emitter,
})
if err != nil {
return nil, trace.Wrap(err)
}
as.SetOIDCService(oas)

return &as, nil
}

Expand Down Expand Up @@ -394,9 +398,7 @@ var (
// - same for users and their sessions
// - checks public keys to see if they're signed by it (can be trusted or not)
type Server struct {
lock sync.RWMutex
// oidcClients is a map from authID & proxyAddr -> oidcClient
oidcClients map[string]*oidcClient
lock sync.RWMutex
githubClients map[string]*githubClient
clock clockwork.Clock
bk backend.Backend
Expand All @@ -405,6 +407,7 @@ type Server struct {
cancelFunc context.CancelFunc

samlAuthService SAMLService
oidcAuthService OIDCService

sshca.Authority

Expand Down Expand Up @@ -455,9 +458,6 @@ type Server struct {
// lockWatcher is a lock watcher, used to verify cert generation requests.
lockWatcher *services.LockWatcher

// getClaimsFun is used in tests for overriding the implementation of getClaims method used in OIDC.
getClaimsFun func(closeCtx context.Context, oidcClient *oidc.Client, connector types.OIDCConnector, code string) (jose.Claims, error)

inventory *inventory.Controller

// githubOrgSSOCache is used to cache whether Github organizations use
Expand All @@ -483,13 +483,20 @@ type Server struct {
loadAllCAs bool
}

// SetSAMLService registers ss as the SAMLService that provides the SAML
// SetSAMLService registers svc as the SAMLService that provides the SAML
// connector implementation. If a SAMLService has already been registered, this
// will override the previous registration.
func (a *Server) SetSAMLService(svc SAMLService) {
a.samlAuthService = svc
}

// SetOIDCService registers svc as the OIDCService that provides the OIDC
// connector implementation. If a OIDCService has already been registered, this
// will override the previous registration.
func (a *Server) SetOIDCService(svc OIDCService) {
a.oidcAuthService = svc
}

func (a *Server) CloseContext() context.Context {
return a.closeCtx
}
Expand Down Expand Up @@ -3990,17 +3997,6 @@ const (
SessionTokenBytes = 32
)

// oidcClient is internal structure that stores OIDC client and its config
type oidcClient struct {
client *oidc.Client
connector types.OIDCConnector
// syncCtx controls the provider sync goroutine.
syncCtx context.Context
syncCancel context.CancelFunc
// firstSync will be closed once the first provider sync succeeds
firstSync chan struct{}
}

// githubClient is internal structure that stores Github OAuth 2client and its config
type githubClient struct {
client *oauth2.Client
Expand Down Expand Up @@ -4038,19 +4034,6 @@ func oauth2ConfigsEqual(a, b oauth2.Config) bool {
return true
}

// isHTTPS checks if the scheme for a URL is https or not.
func isHTTPS(u string) error {
earl, err := url.Parse(u)
if err != nil {
return trace.Wrap(err)
}
if earl.Scheme != "https" {
return trace.BadParameter("expected scheme https, got %q", earl.Scheme)
}

return nil
}

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
Expand Down
103 changes: 0 additions & 103 deletions lib/auth/auth_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@ import (
"crypto/rand"
"crypto/rsa"
"crypto/x509/pkix"
"encoding/json"
"errors"
"fmt"
mathrand "math/rand"
Expand All @@ -30,7 +29,6 @@ import (
"testing"
"time"

"github.com/coreos/go-oidc/jose"
"github.com/google/go-cmp/cmp"
"github.com/google/uuid"
reporting "github.com/gravitational/reporting/types"
Expand Down Expand Up @@ -754,107 +752,6 @@ func TestGenerateTokenEventsEmitted(t *testing.T) {
require.Equal(t, s.mockEmitter.LastEvent().GetType(), events.TrustedClusterTokenCreateEvent)
}

func TestValidateACRValues(t *testing.T) {
s := newAuthSuite(t)

tests := []struct {
comment string
inIDToken string
inACRValue string
inACRProvider string
outIsValid require.ErrorAssertionFunc
}{
{
"0 - default, acr values match",
`
{
"acr": "foo",
"aud": "00000000-0000-0000-0000-000000000000",
"exp": 1111111111
}
`,
"foo",
"",
require.NoError,
},
{
"1 - default, acr values do not match",
`
{
"acr": "foo",
"aud": "00000000-0000-0000-0000-000000000000",
"exp": 1111111111
}
`,
"bar",
"",
require.Error,
},
{
"2 - netiq, acr values match",
`
{
"acr": {
"values": [
"foo/bar/baz"
]
},
"aud": "00000000-0000-0000-0000-000000000000",
"exp": 1111111111
}
`,
"foo/bar/baz",
"netiq",
require.NoError,
},
{
"3 - netiq, invalid format",
`
{
"acr": {
"values": "foo/bar/baz"
},
"aud": "00000000-0000-0000-0000-000000000000",
"exp": 1111111111
}
`,
"foo/bar/baz",
"netiq",
require.Error,
},
{
"4 - netiq, invalid value",
`
{
"acr": {
"values": [
"foo/bar/baz/qux"
]
},
"aud": "00000000-0000-0000-0000-000000000000",
"exp": 1111111111
}
`,
"foo/bar/baz",
"netiq",
require.Error,
},
}

for _, tt := range tests {
tt := tt
t.Run(tt.comment, func(t *testing.T) {
t.Parallel()
var claims jose.Claims
err := json.Unmarshal([]byte(tt.inIDToken), &claims)
require.NoError(t, err)

err = s.a.validateACRValues(tt.inACRValue, tt.inACRProvider, claims)
tt.outIsValid(t, err)
})
}
}

func TestUpdateConfig(t *testing.T) {
t.Parallel()
s := newAuthSuite(t)
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/clt.go
Original file line number Diff line number Diff line change
Expand Up @@ -873,13 +873,13 @@ func (c *Client) GenerateHostCert(

// ValidateOIDCAuthCallback validates OIDC auth callback returned from redirect
func (c *Client) ValidateOIDCAuthCallback(ctx context.Context, q url.Values) (*OIDCAuthResponse, error) {
out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), validateOIDCAuthCallbackReq{
out, err := c.PostJSON(ctx, c.Endpoint("oidc", "requests", "validate"), ValidateOIDCAuthCallbackReq{
Query: q,
})
if err != nil {
return nil, trace.Wrap(err)
}
var rawResponse *oidcAuthRawResponse
var rawResponse *OIDCAuthRawResponse
if err := json.Unmarshal(out.Bytes(), &rawResponse); err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading