Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions api/proto/teleport/legacy/types/types.proto
Original file line number Diff line number Diff line change
Expand Up @@ -462,6 +462,8 @@ message AWS {
// If not, the user must update the AWS profile identity to allow access to the Database.
// Eg for an RDS Database: the underlying AWS profile allows for `rds-db:connect` for the Database.
IAMPolicyStatus IAMPolicyStatus = 14 [(gogoproto.jsontag) = "iam_policy_status"];
// SessionTags is a list of AWS STS session tags.
map<string, string> SessionTags = 15 [(gogoproto.jsontag) = "session_tags,omitempty"];
}

// SecretStore contains secret store configurations.
Expand Down
3,188 changes: 1,674 additions & 1,514 deletions api/types/types.pb.go

Large diffs are not rendered by default.

12 changes: 12 additions & 0 deletions lib/config/configuration.go
Original file line number Diff line number Diff line change
Expand Up @@ -164,6 +164,8 @@ type CommandLineFlags struct {
DatabaseAWSElastiCacheGroupID string
// DatabaseAWSMemoryDBClusterName is the MemoryDB cluster name.
DatabaseAWSMemoryDBClusterName string
// DatabaseAWSSessionTags is the AWS STS session tags.
DatabaseAWSSessionTags string
// DatabaseGCPProjectID is GCP Cloud SQL project identifier.
DatabaseGCPProjectID string
// DatabaseGCPInstanceID is GCP Cloud SQL instance identifier.
Expand Down Expand Up @@ -1667,6 +1669,7 @@ func applyDatabasesConfig(fc *FileConfig, cfg *servicecfg.Config) error {
AssumeRoleARN: database.AWS.AssumeRoleARN,
ExternalID: database.AWS.ExternalID,
Region: database.AWS.Region,
SessionTags: database.AWS.SessionTags,
Redshift: servicecfg.DatabaseAWSRedshift{
ClusterID: database.AWS.Redshift.ClusterID,
},
Expand Down Expand Up @@ -2234,6 +2237,14 @@ func Configure(clf *CommandLineFlags, cfg *servicecfg.Config, legacyAppFlags boo
return trace.Wrap(err)
}
}
var sessionTags map[string]string
if clf.DatabaseAWSSessionTags != "" {
var err error
sessionTags, err = client.ParseLabelSpec(clf.DatabaseAWSSessionTags)
if err != nil {
return trace.Wrap(err)
}
}
db := servicecfg.Database{
Name: clf.DatabaseName,
Description: clf.DatabaseDescription,
Expand All @@ -2252,6 +2263,7 @@ func Configure(clf *CommandLineFlags, cfg *servicecfg.Config, legacyAppFlags boo
AccountID: clf.DatabaseAWSAccountID,
AssumeRoleARN: clf.DatabaseAWSAssumeRoleARN,
ExternalID: clf.DatabaseAWSExternalID,
SessionTags: sessionTags,
Redshift: servicecfg.DatabaseAWSRedshift{
ClusterID: clf.DatabaseAWSRedshiftClusterID,
},
Expand Down
35 changes: 35 additions & 0 deletions lib/config/configuration_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2892,6 +2892,41 @@ func TestDatabaseCLIFlags(t *testing.T) {
},
},
},
{
desc: "AWS DynamoDB with session tags",
inFlags: CommandLineFlags{
DatabaseName: "ddb",
DatabaseProtocol: defaults.ProtocolDynamoDB,
DatabaseURI: "dynamodb.us-east-1.amazonaws.com",
DatabaseAWSAccountID: "123456789012",
DatabaseAWSRegion: "us-east-1",
DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer",
DatabaseAWSExternalID: "externalID123",
DatabaseAWSSessionTags: "database_name=hello,something=else",
},
outDatabase: servicecfg.Database{
Name: "ddb",
Protocol: defaults.ProtocolDynamoDB,
URI: "dynamodb.us-east-1.amazonaws.com",
AWS: servicecfg.DatabaseAWS{
Region: "us-east-1",
AccountID: "123456789012",
AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer",
ExternalID: "externalID123",
SessionTags: map[string]string{
"database_name": "hello",
"something": "else",
},
},
StaticLabels: map[string]string{
types.OriginLabel: types.OriginConfigFile,
},
DynamicLabels: services.CommandLabels{},
TLS: servicecfg.DatabaseTLS{
Mode: servicecfg.VerifyFull,
},
},
},
}

for _, tt := range tests {
Expand Down
2 changes: 2 additions & 0 deletions lib/config/fileconf.go
Original file line number Diff line number Diff line change
Expand Up @@ -1788,6 +1788,8 @@ type DatabaseAWS struct {
ExternalID string `yaml:"external_id,omitempty"`
// RedshiftServerless contains RedshiftServerless specific settings.
RedshiftServerless DatabaseAWSRedshiftServerless `yaml:"redshift_serverless"`
// SessionTags is a list of AWS STS session tags.
SessionTags map[string]string `yaml:"session_tags,omitempty"`
}

// DatabaseAWSRedshift contains AWS Redshift specific settings.
Expand Down
2 changes: 1 addition & 1 deletion lib/configurators/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -260,7 +260,7 @@ var (
}
// dynamodbActions contains IAM actions for static AWS DynamoDB databases.
dynamodbActions = databaseActions{
authBoundary: stsActions,
authBoundary: append(stsActions, "sts:TagSession"),
}
// opensearchActions contains IAM actions for types.AWSMatcherOpenSearch
opensearchActions = databaseActions{
Expand Down
2 changes: 1 addition & 1 deletion lib/configurators/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -761,7 +761,7 @@ func TestAWSIAMDocuments(t *testing.T) {
{
Effect: awslib.EffectAllow,
Resources: awslib.SliceOrString{"*"},
Actions: awslib.SliceOrString{"sts:AssumeRole"},
Actions: awslib.SliceOrString{"sts:AssumeRole", "sts:TagSession"},
},
},
},
Expand Down
3 changes: 3 additions & 0 deletions lib/service/servicecfg/database.go
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ func (d *Database) ToDatabase() (types.Database, error) {
AssumeRoleARN: d.AWS.AssumeRoleARN,
ExternalID: d.AWS.ExternalID,
Region: d.AWS.Region,
SessionTags: d.AWS.SessionTags,
Redshift: types.Redshift{
ClusterID: d.AWS.Redshift.ClusterID,
},
Expand Down Expand Up @@ -260,6 +261,8 @@ type DatabaseAWS struct {
ExternalID string
// RedshiftServerless contains AWS Redshift Serverless specific settings.
RedshiftServerless DatabaseAWSRedshiftServerless
// SessionTags is a list of AWS STS session tags.
SessionTags map[string]string
}

// DatabaseAWSRedshift contains AWS Redshift specific settings.
Expand Down
1 change: 1 addition & 0 deletions lib/srv/db/dynamodb/engine.go
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,7 @@ func (e *Engine) process(ctx context.Context, req *http.Request, signer *libaws.
Expiry: e.sessionCtx.Identity.Expires,
SessionName: e.sessionCtx.Identity.Username,
AWSRoleArn: roleArn,
SessionTags: e.sessionCtx.Database.GetAWS().SessionTags,
}
if meta.AssumeRoleARN == "" {
signingCtx.AWSExternalID = meta.ExternalID
Expand Down
43 changes: 42 additions & 1 deletion lib/utils/aws/credentials.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,15 @@ package aws

import (
"context"
"sort"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/client"
"github.com/aws/aws-sdk-go/aws/credentials"
"github.com/aws/aws-sdk-go/aws/credentials/stscreds"
"github.com/aws/aws-sdk-go/service/sts"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/sirupsen/logrus"
Expand All @@ -40,6 +43,8 @@ type GetCredentialsRequest struct {
RoleARN string
// ExternalID is the external ID to be requested, if not empty.
ExternalID string
// Tags is a list of AWS STS session tags.
Tags map[string]string
}

// CredentialsGetter defines an interface for obtaining STS credentials.
Expand Down Expand Up @@ -67,6 +72,11 @@ func (g *credentialsGetter) Get(_ context.Context, request GetCredentialsRequest
if request.ExternalID != "" {
cred.ExternalID = aws.String(request.ExternalID)
}

cred.Tags = make([]*sts.Tag, 0, len(request.Tags))
for key, value := range request.Tags {
cred.Tags = append(cred.Tags, &sts.Tag{Key: aws.String(key), Value: aws.String(value)})
}
},
), nil
}
Expand Down Expand Up @@ -94,6 +104,37 @@ func (c *CachedCredentialsGetterConfig) SetDefaults() {
}
}

// credentialRequestCacheKey credentials request cache key.
type credentialRequestCacheKey struct {
provider client.ConfigProvider
expiry time.Time
sessionName string
roleARN string
externalID string
tags string
}

// newCredentialRequestCacheKey creates a new cache key for the credentials
// request.
func newCredentialRequestCacheKey(req GetCredentialsRequest) credentialRequestCacheKey {
k := credentialRequestCacheKey{
provider: req.Provider,
expiry: req.Expiry,
sessionName: req.SessionName,
roleARN: req.RoleARN,
externalID: req.ExternalID,
}

tags := make([]string, 0, len(req.Tags))
for key, value := range req.Tags {
tags = append(tags, key+"="+value+",")
}
sort.Strings(tags)
k.tags = strings.Join(tags, ",")

return k
}

type cachedCredentialsGetter struct {
config CachedCredentialsGetterConfig
cache *utils.FnCache
Expand All @@ -120,7 +161,7 @@ func NewCachedCredentialsGetter(config CachedCredentialsGetterConfig) (Credentia
// Get returns cached credentials if found, or fetch it from the configured
// getter.
func (g *cachedCredentialsGetter) Get(ctx context.Context, request GetCredentialsRequest) (*credentials.Credentials, error) {
credentials, err := utils.FnCacheGet(ctx, g.cache, request, func(ctx context.Context) (*credentials.Credentials, error) {
credentials, err := utils.FnCacheGet(ctx, g.cache, newCredentialRequestCacheKey(request), func(ctx context.Context) (*credentials.Credentials, error) {
credentials, err := g.config.Getter.Get(ctx, request)
return credentials, trace.Wrap(err)
})
Expand Down
85 changes: 69 additions & 16 deletions lib/utils/aws/credentials_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,11 @@ func TestCachedCredentialsGetter(t *testing.T) {
Expiry: fakeClock.Now().Add(time.Hour),
SessionName: "test-session",
RoleARN: "test-role",
Tags: map[string]string{
"one": "1",
"two": "2",
"three": "3",
},
})
require.NoError(t, err)
checkCredentialsAccessKeyID(t, cred1, "test-session-test-role-")
Expand All @@ -68,38 +73,85 @@ func TestCachedCredentialsGetter(t *testing.T) {
sessionName string
roleARN string
externalID string
tags map[string]string
advanceClock time.Duration
compareCred1 require.ComparisonAssertionFunc
}{
{
name: "cached",
sessionName: "test-session",
roleARN: "test-role",
name: "cached",
sessionName: "test-session",
roleARN: "test-role",
tags: map[string]string{
"one": "1",
"two": "2",
"three": "3",
},
compareCred1: require.Same,
},
{
name: "cached different tags order",
sessionName: "test-session",
roleARN: "test-role",
tags: map[string]string{
"three": "3",
"two": "2",
"one": "1",
},
compareCred1: require.Same,
},
{
name: "different session name",
sessionName: "test-session-2",
roleARN: "test-role",
name: "different session name",
sessionName: "test-session-2",
roleARN: "test-role",
tags: map[string]string{
"one": "1",
"two": "2",
"three": "3",
},
compareCred1: require.NotSame,
},
{
name: "different role ARN",
sessionName: "test-session",
roleARN: "test-role-2",
tags: map[string]string{
"one": "1",
"two": "2",
"three": "3",
},
compareCred1: require.NotSame,
},
{
name: "different role ARN",
sessionName: "test-session",
roleARN: "test-role-2",
name: "different external ID",
sessionName: "test-session",
roleARN: "test-role",
externalID: "test-id",
tags: map[string]string{
"one": "1",
"two": "2",
"three": "3",
},
compareCred1: require.NotSame,
},
{
name: "different external ID",
sessionName: "test-session",
roleARN: "test-role",
externalID: "test-id",
name: "different tags",
sessionName: "test-session",
roleARN: "test-role",
tags: map[string]string{
"four": "4",
"five": "5",
},
compareCred1: require.NotSame,
},
{
name: "cache expired",
sessionName: "test-session",
roleARN: "test-role",
name: "cache expired",
sessionName: "test-session",
roleARN: "test-role",
tags: map[string]string{
"one": "1",
"two": "2",
"three": "3",
},
advanceClock: time.Hour,
compareCred1: require.NotSame,
},
Expand All @@ -115,6 +167,7 @@ func TestCachedCredentialsGetter(t *testing.T) {
SessionName: test.sessionName,
RoleARN: test.roleARN,
ExternalID: test.externalID,
Tags: test.tags,
})
require.NoError(t, err)
checkCredentialsAccessKeyID(t, cred, fmt.Sprintf("%s-%s-%s", test.sessionName, test.roleARN, test.externalID))
Expand Down
3 changes: 3 additions & 0 deletions lib/utils/aws/signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,8 @@ type SigningCtx struct {
AWSRoleArn string
// AWSExternalID is an optional external ID used when getting sts credentials.
AWSExternalID string
// SessionTags is a list of AWS STS session tags.
SessionTags map[string]string
}

// Check checks signing context parameters.
Expand Down Expand Up @@ -167,6 +169,7 @@ func (s *SigningService) SignRequest(ctx context.Context, req *http.Request, sig
SessionName: signCtx.SessionName,
RoleARN: signCtx.AWSRoleArn,
ExternalID: signCtx.AWSExternalID,
Tags: signCtx.SessionTags,
})
if err != nil {
return nil, trace.Wrap(err)
Expand Down
1 change: 1 addition & 0 deletions tool/teleport/common/teleport.go
Original file line number Diff line number Diff line change
Expand Up @@ -245,6 +245,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con
dbStartCmd.Flag("aws-redshift-cluster-id", "(Only for Redshift) Redshift database cluster identifier.").StringVar(&ccf.DatabaseAWSRedshiftClusterID)
dbStartCmd.Flag("aws-rds-instance-id", "(Only for RDS) RDS instance identifier.").StringVar(&ccf.DatabaseAWSRDSInstanceID)
dbStartCmd.Flag("aws-rds-cluster-id", "(Only for Aurora) Aurora cluster identifier.").StringVar(&ccf.DatabaseAWSRDSClusterID)
dbStartCmd.Flag("aws-session-tags", "(Only for DynamoDB) List of STS tags.").StringVar(&ccf.DatabaseAWSSessionTags)
dbStartCmd.Flag("gcp-project-id", "(Only for Cloud SQL) GCP Cloud SQL project identifier.").StringVar(&ccf.DatabaseGCPProjectID)
dbStartCmd.Flag("gcp-instance-id", "(Only for Cloud SQL) GCP Cloud SQL instance identifier.").StringVar(&ccf.DatabaseGCPInstanceID)
dbStartCmd.Flag("ad-keytab-file", "(Only for SQL Server) Kerberos keytab file.").StringVar(&ccf.DatabaseADKeytabFile)
Expand Down