Skip to content

Commit

Permalink
Split auth.AccessPoint into variant specific interfaces (#8471)
Browse files Browse the repository at this point in the history
(cherry picked from commit 5cd7c3c)
  • Loading branch information
rosstimothy committed Nov 5, 2021
1 parent f3962c4 commit 88ea7e8
Show file tree
Hide file tree
Showing 49 changed files with 1,047 additions and 480 deletions.
909 changes: 728 additions & 181 deletions lib/auth/api.go

Large diffs are not rendered by default.

28 changes: 13 additions & 15 deletions lib/auth/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2688,32 +2688,32 @@ func (a *Server) GetToken(ctx context.Context, token string) (types.ProvisionTok
return a.GetCache().GetToken(ctx, token)
}

// GetRoles is a part of auth.AccessPoint implementation
// GetRoles returns roles from the cache
func (a *Server) GetRoles(ctx context.Context) ([]types.Role, error) {
return a.GetCache().GetRoles(ctx)
}

// GetRole is a part of auth.AccessPoint implementation
// GetRole returns a role from the cache
func (a *Server) GetRole(ctx context.Context, name string) (types.Role, error) {
return a.GetCache().GetRole(ctx, name)
}

// GetNamespace returns namespace
// GetNamespace returns a namespace from the cache
func (a *Server) GetNamespace(name string) (*types.Namespace, error) {
return a.GetCache().GetNamespace(name)
}

// GetNamespaces is a part of auth.AccessPoint implementation
// GetNamespaces returns namespaces from the cache
func (a *Server) GetNamespaces() ([]types.Namespace, error) {
return a.GetCache().GetNamespaces()
}

// GetNodes is a part of auth.AccessPoint implementation
// GetNodes returns nodes from the cache
func (a *Server) GetNodes(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) {
return a.GetCache().GetNodes(ctx, namespace, opts...)
}

// ListNodes is a part of auth.AccessPoint implementation
// ListNodes lists nodes from the cache
func (a *Server) ListNodes(ctx context.Context, req proto.ListNodesRequest) ([]types.Server, string, error) {
return a.GetCache().ListNodes(ctx, req)
}
Expand Down Expand Up @@ -2744,34 +2744,32 @@ func (a *Server) IterateNodePages(ctx context.Context, req proto.ListNodesReques
}
}

// GetReverseTunnels is a part of auth.AccessPoint implementation
// GetReverseTunnels returns reverse tunnels from the cache
func (a *Server) GetReverseTunnels(opts ...services.MarshalOption) ([]types.ReverseTunnel, error) {
return a.GetCache().GetReverseTunnels(opts...)
}

// GetProxies is a part of auth.AccessPoint implementation
// GetProxies returns proxies from the cache
func (a *Server) GetProxies() ([]types.Server, error) {
return a.GetCache().GetProxies()
}

// GetUser is a part of auth.AccessPoint implementation.
// GetUser returns a user from the cache
func (a *Server) GetUser(name string, withSecrets bool) (user types.User, err error) {
return a.GetCache().GetUser(name, withSecrets)
}

// GetUsers is a part of auth.AccessPoint implementation
// GetUsers returns users from the cache
func (a *Server) GetUsers(withSecrets bool) (users []types.User, err error) {
return a.GetCache().GetUsers(withSecrets)
}

// GetTunnelConnections is a part of auth.AccessPoint implementation
// GetTunnelConnections are not using recent cache as they are designed
// to be called periodically and always return fresh data
func (a *Server) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) {
return a.GetCache().GetTunnelConnections(clusterName, opts...)
}

// GetAllTunnelConnections is a part of auth.AccessPoint implementation
// GetAllTunnelConnections are not using recent cache, as they are designed
// to be called periodically and always return fresh data
func (a *Server) GetAllTunnelConnections(opts ...services.MarshalOption) (conns []types.TunnelConnection, err error) {
Expand Down Expand Up @@ -2814,12 +2812,12 @@ func (a *Server) modeStreamer(ctx context.Context) (events.Streamer, error) {
return a.streamer, nil
}

// GetAppServers is a part of the auth.AccessPoint implementation.
// GetAppServers returns app servers from the cache
func (a *Server) GetAppServers(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) {
return a.GetCache().GetAppServers(ctx, namespace, opts...)
}

// GetAppSession is a part of the auth.AccessPoint implementation.
// GetAppSession returns app sessions from the cache
func (a *Server) GetAppSession(ctx context.Context, req types.GetAppSessionRequest) (types.WebSession, error) {
return a.GetCache().GetAppSession(ctx, req)
}
Expand Down Expand Up @@ -3707,7 +3705,7 @@ func isHTTPS(u string) error {

// WithClusterCAs returns a TLS hello callback that returns a copy of the provided
// TLS config with client CAs pool of the specified cluster.
func WithClusterCAs(tlsConfig *tls.Config, ap AccessPoint, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
func WithClusterCAs(tlsConfig *tls.Config, ap AccessCache, currentClusterName string, log logrus.FieldLogger) func(*tls.ClientHelloInfo) (*tls.Config, error) {
return func(info *tls.ClientHelloInfo) (*tls.Config, error) {
var clusterName string
var err error
Expand Down
37 changes: 32 additions & 5 deletions lib/auth/permissions.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ func NewBuiltinRoleContext(role types.SystemRole) (*Context, error) {
}

// NewAuthorizer returns new authorizer using backends
func NewAuthorizer(clusterName string, accessPoint ReadAccessPoint, lockWatcher *services.LockWatcher) (Authorizer, error) {
func NewAuthorizer(clusterName string, accessPoint AuthorizerAccessPoint, lockWatcher *services.LockWatcher) (Authorizer, error) {
if clusterName == "" {
return nil, trace.BadParameter("missing parameter clusterName")
}
Expand All @@ -68,16 +68,43 @@ type Authorizer interface {
Authorize(ctx context.Context) (*Context, error)
}

// AuthorizerAccessPoint is the access point contract required by an Authorizer
type AuthorizerAccessPoint interface {
// GetAuthPreference returns the cluster authentication configuration.
GetAuthPreference(ctx context.Context) (types.AuthPreference, error)

// GetRole returns role by name
GetRole(ctx context.Context, name string) (types.Role, error)

// GetUser returns a services.User for this cluster.
GetUser(name string, withSecrets bool) (types.User, error)

// GetCertAuthority returns cert authority by id
GetCertAuthority(id types.CertAuthID, loadKeys bool, opts ...services.MarshalOption) (types.CertAuthority, error)

// GetCertAuthorities returns a list of cert authorities
GetCertAuthorities(caType types.CertAuthType, loadKeys bool, opts ...services.MarshalOption) ([]types.CertAuthority, error)

// GetClusterAuditConfig returns cluster audit configuration.
GetClusterAuditConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterAuditConfig, error)

// GetClusterNetworkingConfig returns cluster networking configuration.
GetClusterNetworkingConfig(ctx context.Context, opts ...services.MarshalOption) (types.ClusterNetworkingConfig, error)

// GetSessionRecordingConfig returns session recording configuration.
GetSessionRecordingConfig(ctx context.Context, opts ...services.MarshalOption) (types.SessionRecordingConfig, error)
}

// authorizer creates new local authorizer
type authorizer struct {
clusterName string
accessPoint ReadAccessPoint
accessPoint AuthorizerAccessPoint
lockWatcher *services.LockWatcher
}

// Context is authorization context
type Context struct {
// User is the user name
// User is the username
User types.User
// Checker is access checker
Checker services.AccessChecker
Expand Down Expand Up @@ -669,7 +696,7 @@ func contextForBuiltinRole(r BuiltinRole, recConfig types.SessionRecordingConfig
}, nil
}

func contextForLocalUser(u LocalUser, accessPoint ReadAccessPoint) (*Context, error) {
func contextForLocalUser(u LocalUser, accessPoint AuthorizerAccessPoint) (*Context, error) {
// User has to be fetched to check if it's a blocked username
user, err := accessPoint.GetUser(u.Username, false)
if err != nil {
Expand All @@ -684,7 +711,7 @@ func contextForLocalUser(u LocalUser, accessPoint ReadAccessPoint) (*Context, er
return nil, trace.Wrap(err)
}
// Override roles and traits from the local user based on the identity roles
// and traits, this is done to prevent potential conflict. Imagine a scenairo
// and traits, this is done to prevent potential conflict. Imagine a scenario
// when SSO user has left the company, but local user entry remained with old
// privileged roles. New user with the same name has been onboarded and would
// have derived the roles from the stale user entry. This code prevents
Expand Down
2 changes: 1 addition & 1 deletion lib/auth/sessions.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ func (s *Server) CreateAppSession(ctx context.Context, req types.CreateAppSessio

// WaitForAppSession will block until the requested application session shows up in the
// cache or a timeout occurs.
func WaitForAppSession(ctx context.Context, sessionID, user string, ap AccessPoint) error {
func WaitForAppSession(ctx context.Context, sessionID, user string, ap ReadProxyAccessPoint) error {
_, err := ap.GetAppSession(ctx, types.GetAppSessionRequest{SessionID: sessionID})
if err == nil {
return nil
Expand Down
4 changes: 2 additions & 2 deletions lib/auth/tls_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -1532,7 +1532,7 @@ func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check
c.Assert(err, check.IsNil)

// Roles extracted from cert should contain the initial role and the role assigned with access request.
roles, _, err := services.ExtractFromCertificate(clt, sshcert)
roles, _, err := services.ExtractFromCertificate(sshcert)
c.Assert(err, check.IsNil)
c.Assert(roles, check.HasLen, 2)

Expand Down Expand Up @@ -1560,7 +1560,7 @@ func (s *TLSSuite) TestWebSessionWithApprovedAccessRequestAndSwitchback(c *check
sshcert, err = sshutils.ParseCertificate(sess2.GetPub())
c.Assert(err, check.IsNil)

roles, _, err = services.ExtractFromCertificate(clt, sshcert)
roles, _, err = services.ExtractFromCertificate(sshcert)
c.Assert(err, check.IsNil)
c.Assert(roles, check.DeepEquals, []string{initialRole})
}
Expand Down
31 changes: 13 additions & 18 deletions lib/cache/cache.go
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,6 @@ func ForNode(cfg Config) Config {
{Kind: types.KindClusterNetworkingConfig},
{Kind: types.KindClusterAuthPreference},
{Kind: types.KindSessionRecordingConfig},
{Kind: types.KindUser},
{Kind: types.KindRole},
// Node only needs to "know" about default
// namespace events to avoid matching too much
Expand Down Expand Up @@ -282,7 +281,7 @@ func ForWindowsDesktop(cfg Config) Config {
// for cache
type SetupConfigFn func(c Config) Config

// Cache implements auth.AccessPoint interface and remembers
// Cache implements auth.Cache interface and remembers
// the previously returned upstream value for each API call.
//
// This which can be used if the upstream AccessPoint goes offline
Expand Down Expand Up @@ -1197,7 +1196,7 @@ func (c *Cache) GetClusterName(opts ...services.MarshalOption) (types.ClusterNam
return rg.clusterConfig.GetClusterName(opts...)
}

// GetRoles is a part of auth.AccessPoint implementation
// GetRoles is a part of auth.Cache implementation
func (c *Cache) GetRoles(ctx context.Context) ([]types.Role, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1207,7 +1206,7 @@ func (c *Cache) GetRoles(ctx context.Context) ([]types.Role, error) {
return rg.access.GetRoles(ctx)
}

// GetRole is a part of auth.AccessPoint implementation
// GetRole is a part of auth.Cache implementation
func (c *Cache) GetRole(ctx context.Context, name string) (types.Role, error) {
rg, err := c.read()
if err != nil {
Expand Down Expand Up @@ -1237,7 +1236,7 @@ func (c *Cache) GetNamespace(name string) (*types.Namespace, error) {
return rg.presence.GetNamespace(name)
}

// GetNamespaces is a part of auth.AccessPoint implementation
// GetNamespaces is a part of auth.Cache implementation
func (c *Cache) GetNamespaces() ([]types.Namespace, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1263,7 +1262,7 @@ type getNodesCacheKey struct {

var _ map[getNodesCacheKey]struct{} // compile-time hashability check

// GetNodes is a part of auth.AccessPoint implementation
// GetNodes is a part of auth.Cache implementation
func (c *Cache) GetNodes(ctx context.Context, namespace string, opts ...services.MarshalOption) ([]types.Server, error) {
rg, err := c.read()
if err != nil {
Expand Down Expand Up @@ -1295,7 +1294,7 @@ func (c *Cache) GetNodes(ctx context.Context, namespace string, opts ...services
return rg.presence.GetNodes(ctx, namespace, opts...)
}

// ListNodes is a part of auth.AccessPoint implementation
// ListNodes is a part of auth.Cache implementation
func (c *Cache) ListNodes(ctx context.Context, req proto.ListNodesRequest) ([]types.Server, string, error) {
// NOTE: we "fake" the ListNodes API here in order to take advantage of TTL-based caching of
// the GetNodes endpoint, since performing TTL-based caching on a paginated endpoint is nightmarish.
Expand Down Expand Up @@ -1353,7 +1352,7 @@ func (c *Cache) GetAuthServers() ([]types.Server, error) {
return rg.presence.GetAuthServers()
}

// GetReverseTunnels is a part of auth.AccessPoint implementation
// GetReverseTunnels is a part of auth.Cache implementation
func (c *Cache) GetReverseTunnels(opts ...services.MarshalOption) ([]types.ReverseTunnel, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1363,7 +1362,7 @@ func (c *Cache) GetReverseTunnels(opts ...services.MarshalOption) ([]types.Rever
return rg.presence.GetReverseTunnels(opts...)
}

// GetProxies is a part of auth.AccessPoint implementation
// GetProxies is a part of auth.Cache implementation
func (c *Cache) GetProxies() ([]types.Server, error) {
rg, err := c.read()
if err != nil {
Expand Down Expand Up @@ -1431,7 +1430,7 @@ func (c *Cache) GetRemoteCluster(clusterName string) (types.RemoteCluster, error
return rg.presence.GetRemoteCluster(clusterName)
}

// GetUser is a part of auth.AccessPoint implementation.
// GetUser is a part of auth.Cache implementation.
func (c *Cache) GetUser(name string, withSecrets bool) (user types.User, err error) {
if withSecrets { // cache never tracks user secrets
return c.Config.Users.GetUser(name, withSecrets)
Expand All @@ -1455,7 +1454,7 @@ func (c *Cache) GetUser(name string, withSecrets bool) (user types.User, err err
return user, trace.Wrap(err)
}

// GetUsers is a part of auth.AccessPoint implementation
// GetUsers is a part of auth.Cache implementation
func (c *Cache) GetUsers(withSecrets bool) (users []types.User, err error) {
if withSecrets { // cache never tracks user secrets
return c.Users.GetUsers(withSecrets)
Expand All @@ -1468,9 +1467,7 @@ func (c *Cache) GetUsers(withSecrets bool) (users []types.User, err error) {
return rg.users.GetUsers(withSecrets)
}

// GetTunnelConnections is a part of auth.AccessPoint implementation
// GetTunnelConnections are not using recent cache as they are designed
// to be called periodically and always return fresh data
// GetTunnelConnections is a part of auth.Cache implementation
func (c *Cache) GetTunnelConnections(clusterName string, opts ...services.MarshalOption) ([]types.TunnelConnection, error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1480,9 +1477,7 @@ func (c *Cache) GetTunnelConnections(clusterName string, opts ...services.Marsha
return rg.presence.GetTunnelConnections(clusterName, opts...)
}

// GetAllTunnelConnections is a part of auth.AccessPoint implementation
// GetAllTunnelConnections are not using recent cache, as they are designed
// to be called periodically and always return fresh data
// GetAllTunnelConnections is a part of auth.Cache implementation
func (c *Cache) GetAllTunnelConnections(opts ...services.MarshalOption) (conns []types.TunnelConnection, err error) {
rg, err := c.read()
if err != nil {
Expand All @@ -1492,7 +1487,7 @@ func (c *Cache) GetAllTunnelConnections(opts ...services.MarshalOption) (conns [
return rg.presence.GetAllTunnelConnections(opts...)
}

// GetKubeServices is a part of auth.AccessPoint implementation
// GetKubeServices is a part of auth.Cache implementation
func (c *Cache) GetKubeServices(ctx context.Context) ([]types.Server, error) {
rg, err := c.read()
if err != nil {
Expand Down
12 changes: 0 additions & 12 deletions lib/client/api.go
Original file line number Diff line number Diff line change
Expand Up @@ -1167,18 +1167,6 @@ func (tc *TeleportClient) LoadKeyForClusterWithReissue(ctx context.Context, clus
return nil
}

// accessPoint returns access point based on the cache policy
func (tc *TeleportClient) accessPoint(clt auth.AccessPoint, proxyHostPort string, clusterName string) (auth.AccessPoint, error) {
// If no caching policy was set or on Windows (where Teleport does not
// support file locking at the moment), return direct access to the access
// point.
if tc.CachePolicy == nil || runtime.GOOS == constants.WindowsOS {
log.Debugf("not using caching access point")
return clt, nil
}
return clt, nil
}

// LocalAgent is a getter function for the client's local agent
func (tc *TeleportClient) LocalAgent() *LocalKeyAgent {
return tc.localAgent
Expand Down
6 changes: 3 additions & 3 deletions lib/client/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -624,7 +624,7 @@ func (proxy *ProxyClient) GetDatabaseServers(ctx context.Context, namespace stri
// CurrentClusterAccessPoint returns cluster access point to the currently
// selected cluster and is used for discovery
// and could be cached based on the access policy
func (proxy *ProxyClient) CurrentClusterAccessPoint(ctx context.Context, quiet bool) (auth.AccessPoint, error) {
func (proxy *ProxyClient) CurrentClusterAccessPoint(ctx context.Context, quiet bool) (auth.ClientI, error) {
// get the current cluster:
cluster, err := proxy.currentCluster()
if err != nil {
Expand All @@ -635,15 +635,15 @@ func (proxy *ProxyClient) CurrentClusterAccessPoint(ctx context.Context, quiet b

// ClusterAccessPoint returns cluster access point used for discovery
// and could be cached based on the access policy
func (proxy *ProxyClient) ClusterAccessPoint(ctx context.Context, clusterName string, quiet bool) (auth.AccessPoint, error) {
func (proxy *ProxyClient) ClusterAccessPoint(ctx context.Context, clusterName string, quiet bool) (auth.ClientI, error) {
if clusterName == "" {
return nil, trace.BadParameter("parameter clusterName is missing")
}
clt, err := proxy.ConnectToCluster(ctx, clusterName, quiet)
if err != nil {
return nil, trace.Wrap(err)
}
return proxy.teleportClient.accessPoint(clt, proxy.proxyAddress, clusterName)
return clt, nil
}

// ConnectToCurrentCluster connects to the auth server of the currently selected
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/forwarder.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ type ForwarderConfig struct {
// AuthClient is a auth server client.
AuthClient auth.ClientI
// CachingAuthClient is a caching auth server client for read-only access.
CachingAuthClient auth.AccessPoint
CachingAuthClient auth.ReadKubernetesAccessPoint
// StreamEmitter is used to create audit streams
// and emit audit events
StreamEmitter events.StreamEmitter
Expand Down
2 changes: 1 addition & 1 deletion lib/kube/proxy/forwarder_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -845,7 +845,7 @@ type mockRemoteSite struct {
func (s mockRemoteSite) GetName() string { return s.name }

type mockAccessPoint struct {
auth.AccessPoint
auth.KubernetesAccessPoint

netConfig types.ClusterNetworkingConfig
recordingConfig types.SessionRecordingConfig
Expand Down
Loading

0 comments on commit 88ea7e8

Please sign in to comment.