From 0b75973ed0d1ce4f976eef7b163dad46fadf61d3 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Thu, 16 Mar 2023 14:33:07 -0700 Subject: [PATCH 1/4] improve aws utils and database validation * log a warning when an invalid AWS IAM role ARN is filtered out * add tests for nested role arn path and remove superfluous test parallelism * add extra godoc info for BuildRoleARN * improve mock IAM error messages - Failing tests' error messages were not specific enough to know what happened. * refactor code to call GetAWS() once - Avoids constantly fetching a full copy of the AWS metadata. - Reduces verbosity of code. * reuse MetadataFromRedshiftCluster in meta fetcher - Avoids repeating the same logic inside the fetcher. * improve database-access DynamoDB config checking - Don't early return from database CheckAndSetDefaults so AWS account ID is checked properly, and any other future checks added will apply to DynamoDB as well. * add aws utils for parsing AWS IAM roles - ParseRoleARN will parse an AWS ARN and verify that it is a valid AWS IAM Role resource - CheckARNPartitionAndAccount is a helper that verifies an ARN partition and account ID matches the caller's expectations. * use cmp.Diff instead of require.Equal in db tests - require.Equal spams a lot of protobuf noise. - cmp.Diff yields a precise, readable diff on test failure. --- api/types/database.go | 21 ++-- api/types/database_test.go | 47 ++++----- api/utils/aws/endpoint_test.go | 35 ++++--- lib/cloud/mocks/aws.go | 8 +- lib/services/database_test.go | 47 ++++----- lib/services/role.go | 6 +- lib/srv/app/cloud.go | 4 +- lib/srv/app/server_test.go | 4 +- lib/srv/db/cassandra/handshake.go | 6 +- lib/srv/db/cloud/meta.go | 30 +++--- lib/srv/db/cloud/users/elasticache.go | 9 +- lib/srv/db/cloud/users/helpers.go | 3 +- lib/srv/db/cloud/users/memorydb.go | 7 +- lib/srv/db/common/auth.go | 47 +++------ lib/srv/db/dynamodb/engine.go | 10 +- lib/srv/db/dynamodb_test.go | 2 +- lib/utils/aws/aws.go | 97 +++++++++++++----- lib/utils/aws/aws_test.go | 137 ++++++++++++++++++++++++++ lib/utils/aws/signing.go | 7 +- tool/tsh/app_aws_test.go | 2 +- tool/tsh/db_test.go | 2 +- 21 files changed, 367 insertions(+), 164 deletions(-) diff --git a/api/types/database.go b/api/types/database.go index 0c4ecb7f6dd2b..05dcaa97e6a09 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -418,7 +418,7 @@ func (d *DatabaseV3) getAWSType() (string, bool) { aws := d.GetAWS() switch d.Spec.Protocol { case DatabaseTypeCassandra: - if aws.AccountID != "" { + if !aws.IsEmpty() { return DatabaseTypeAWSKeyspaces, true } case DatabaseTypeDynamoDB: @@ -507,16 +507,19 @@ func (d *DatabaseV3) CheckAndSetDefaults() error { if d.Spec.Protocol == "" { return trace.BadParameter("database %q protocol is empty", d.GetName()) } - if d.IsDynamoDB() { - // DynamoDB gets its own checking logic for its unusual config. - return trace.Wrap(d.handleDynamoDBConfig()) - } if d.Spec.URI == "" { switch { - case d.IsAWSKeyspaces() && d.GetAWS().Region != "": + case d.IsAWSKeyspaces() && d.Spec.AWS.Region != "": // In case of AWS Hosted Cassandra allow to omit URI. // The URL will be constructed from the database resource based on the region and account ID. d.Spec.URI = awsutils.CassandraEndpointURLForRegion(d.Spec.AWS.Region) + case d.IsDynamoDB(): + if d.Spec.AWS.Region != "" { + d.Spec.URI = awsutils.DynamoDBURIForRegion(d.Spec.AWS.Region) + } else { + return trace.BadParameter("DynamoDB database %q URI is missing and cannot be derived from an empty configured AWS region", + d.GetName()) + } default: return trace.BadParameter("database %q URI is empty", d.GetName()) } @@ -529,6 +532,10 @@ func (d *DatabaseV3) CheckAndSetDefaults() error { // In case of RDS, Aurora or Redshift, AWS information such as region or // cluster ID can be extracted from the endpoint if not provided. switch { + case d.IsDynamoDB(): + if err := d.handleDynamoDBConfig(); err != nil { + return trace.Wrap(err) + } case awsutils.IsRDSEndpoint(d.Spec.URI): details, err := awsutils.ParseRDSEndpoint(d.Spec.URI) if err != nil { @@ -697,7 +704,7 @@ func (d *DatabaseV3) handleDynamoDBConfig() error { // so we check if the region is configured to see if this is really a configuration error. if d.Spec.AWS.Region == "" { // the AWS region is empty and we can't derive it from the URI, so this is a config error. - return trace.BadParameter("database %q AWS region is empty and cannot be derived from the URI %q", + return trace.BadParameter("database %q AWS region is missing and cannot be derived from the URI %q", d.GetName(), d.Spec.URI) } if awsutils.IsAWSEndpoint(d.Spec.URI) { diff --git a/api/types/database_test.go b/api/types/database_test.go index de76d3bf2f649..11f6ee4c84ef9 100644 --- a/api/types/database_test.go +++ b/api/types/database_test.go @@ -555,48 +555,48 @@ func TestDynamoDBConfig(t *testing.T) { { desc: "account and region and empty URI is correct", region: "us-west-1", - account: "12345", + account: "123456789012", wantSpec: DatabaseSpecV3{ URI: "aws://dynamodb.us-west-1.amazonaws.com", AWS: AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, }, }, { desc: "account and AWS URI and empty region is correct", uri: "dynamodb.us-west-1.amazonaws.com", - account: "12345", + account: "123456789012", wantSpec: DatabaseSpecV3{ URI: "dynamodb.us-west-1.amazonaws.com", AWS: AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, }, }, { desc: "account and AWS streams dynamodb URI and empty region is correct", uri: "streams.dynamodb.us-west-1.amazonaws.com", - account: "12345", + account: "123456789012", wantSpec: DatabaseSpecV3{ URI: "streams.dynamodb.us-west-1.amazonaws.com", AWS: AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, }, }, { desc: "account and AWS dax URI and empty region is correct", uri: "dax.us-west-1.amazonaws.com", - account: "12345", + account: "123456789012", wantSpec: DatabaseSpecV3{ URI: "dax.us-west-1.amazonaws.com", AWS: AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, }, }, @@ -604,12 +604,12 @@ func TestDynamoDBConfig(t *testing.T) { desc: "account and region and matching AWS URI region is correct", uri: "dynamodb.us-west-1.amazonaws.com", region: "us-west-1", - account: "12345", + account: "123456789012", wantSpec: DatabaseSpecV3{ URI: "dynamodb.us-west-1.amazonaws.com", AWS: AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, }, }, @@ -617,12 +617,12 @@ func TestDynamoDBConfig(t *testing.T) { desc: "account and region and custom URI is correct", uri: "localhost:8080", region: "us-west-1", - account: "12345", + account: "123456789012", wantSpec: DatabaseSpecV3{ URI: "localhost:8080", AWS: AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, }, }, @@ -630,30 +630,33 @@ func TestDynamoDBConfig(t *testing.T) { desc: "region and different AWS URI region is an error", uri: "dynamodb.us-west-2.amazonaws.com", region: "us-west-1", - account: "12345", + account: "123456789012", wantErrMsg: "does not match the configured URI", }, { desc: "invalid AWS URI is an error", uri: "a.streams.dynamodb.us-west-1.amazonaws.com", region: "us-west-1", - account: "12345", + account: "123456789012", wantErrMsg: "invalid DynamoDB endpoint", }, { - desc: "custom URI and empty region is an error", + desc: "custom URI and missing region is an error", uri: "localhost:8080", - account: "12345", - wantErrMsg: "region is empty", + account: "123456789012", + wantErrMsg: "region is missing", }, { - desc: "empty URI and empty region is an error", - account: "12345", - wantErrMsg: "region is empty", + desc: "missing URI and missing region is an error", + account: "123456789012", + wantErrMsg: "URI is missing", }, { - desc: "missing account id", - wantErrMsg: "account ID is empty", + desc: "invalid AWS account ID is an error", + uri: "localhost:8080", + region: "us-west-1", + account: "12345", + wantErrMsg: "must be 12-digit", }, } diff --git a/api/utils/aws/endpoint_test.go b/api/utils/aws/endpoint_test.go index 440658bf54269..cb216af5bbb34 100644 --- a/api/utils/aws/endpoint_test.go +++ b/api/utils/aws/endpoint_test.go @@ -513,30 +513,39 @@ func TestRedshiftServerlessEndpoint(t *testing.T) { func TestDynamoDBURIForRegion(t *testing.T) { t.Parallel() tests := []struct { - desc string - region string - wantURI string + desc string + region string + wantURI string + wantPartition string }{ { - desc: "region is in correct AWS partition", - region: "us-east-1", - wantURI: "aws://dynamodb.us-east-1.amazonaws.com", + desc: "region is in correct AWS partition", + region: "us-east-1", + wantURI: "aws://dynamodb.us-east-1.amazonaws.com", + wantPartition: ".amazonaws.com", }, { - desc: "china north region is in correct AWS partition", - region: "cn-north-1", - wantURI: "aws://dynamodb.cn-north-1.amazonaws.com.cn", + desc: "china north region is in correct AWS partition", + region: "cn-north-1", + wantURI: "aws://dynamodb.cn-north-1.amazonaws.com.cn", + wantPartition: ".amazonaws.com.cn", }, { - desc: "china northwest region is in correct AWS partition", - region: "cn-northwest-1", - wantURI: "aws://dynamodb.cn-northwest-1.amazonaws.com.cn", + desc: "china northwest region is in correct AWS partition", + region: "cn-northwest-1", + wantURI: "aws://dynamodb.cn-northwest-1.amazonaws.com.cn", + wantPartition: ".amazonaws.com.cn", }, } for _, tt := range tests { tt := tt t.Run(tt.desc, func(t *testing.T) { require.Equal(t, tt.wantURI, DynamoDBURIForRegion(tt.region)) + info, err := ParseDynamoDBEndpoint(tt.wantURI) + require.NoError(t, err, "endpoint generated from region could not be parsed.") + require.Equal(t, tt.region, info.Region) + require.Equal(t, "dynamodb", info.Service) + require.Equal(t, tt.wantPartition, info.Partition) }) } } @@ -630,7 +639,7 @@ func TestParseDynamoDBEndpoint(t *testing.T) { } for _, tt := range tests { tt := tt - t.Run("detects invalid endpoint with"+tt.desc, func(t *testing.T) { + t.Run("detects invalid endpoint with "+tt.desc, func(t *testing.T) { t.Parallel() info, err := ParseDynamoDBEndpoint(tt.endpoint) require.Error(t, err, "endpoint %s should be invalid", tt.endpoint) diff --git a/lib/cloud/mocks/aws.go b/lib/cloud/mocks/aws.go index 8b11a75af6bb8..f3771f878adc5 100644 --- a/lib/cloud/mocks/aws.go +++ b/lib/cloud/mocks/aws.go @@ -243,11 +243,11 @@ func (m *IAMMock) GetRolePolicyWithContext(ctx aws.Context, input *iam.GetRolePo defer m.mu.RUnlock() policy, ok := m.attachedRolePolicies[*input.RoleName] if !ok { - return nil, trace.NotFound("policy not found") + return nil, trace.NotFound("role policy %v not found", *input.RoleName) } policyDocument, ok := policy[*input.PolicyName] if !ok { - return nil, trace.NotFound("policy not found") + return nil, trace.NotFound("role %v policy name %v not found", *input.RoleName, *input.PolicyName) } return &iam.GetRolePolicyOutput{ PolicyDocument: &policyDocument, @@ -283,11 +283,11 @@ func (m *IAMMock) GetUserPolicyWithContext(ctx aws.Context, input *iam.GetUserPo defer m.mu.RUnlock() policy, ok := m.attachedUserPolicies[*input.UserName] if !ok { - return nil, trace.NotFound("policy not found") + return nil, trace.NotFound("user policy %v not found", *input.UserName) } policyDocument, ok := policy[*input.PolicyName] if !ok { - return nil, trace.NotFound("policy not found") + return nil, trace.NotFound("user %v policy name %v not found", *input.UserName, *input.PolicyName) } return &iam.GetUserPolicyOutput{ PolicyDocument: &policyDocument, diff --git a/lib/services/database_test.go b/lib/services/database_test.go index e412fdcce26da..eaa48fbbdcf80 100644 --- a/lib/services/database_test.go +++ b/lib/services/database_test.go @@ -35,6 +35,7 @@ import ( "github.com/aws/aws-sdk-go/service/rds" "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" + "github.com/google/go-cmp/cmp" "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" @@ -66,7 +67,7 @@ func TestDatabaseUnmarshal(t *testing.T) { require.NoError(t, err) actual, err := UnmarshalDatabase(data) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } // TestDatabaseMarshal verifies a marshaled database resource can be unmarshaled back. @@ -85,7 +86,7 @@ func TestDatabaseMarshal(t *testing.T) { require.NoError(t, err) actual, err := UnmarshalDatabase(data) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestValidateDatabase(t *testing.T) { @@ -217,7 +218,7 @@ func TestValidateDatabase(t *testing.T) { Protocol: defaults.ProtocolDynamoDB, AWS: types.AWS{ Region: "us-east-1", - AccountID: "1234567890", + AccountID: "123456789012", }, }, expectError: false, @@ -318,7 +319,7 @@ func TestDatabaseFromAzureDBServer(t *testing.T) { actual, err := NewDatabaseFromAzureServer(&server) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromAzureRedis(t *testing.T) { @@ -366,7 +367,7 @@ func TestDatabaseFromAzureRedis(t *testing.T) { actual, err := NewDatabaseFromAzureRedis(resourceInfo) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromAzureRedisEnterprise(t *testing.T) { @@ -428,7 +429,7 @@ func TestDatabaseFromAzureRedisEnterprise(t *testing.T) { actual, err := NewDatabaseFromAzureRedisEnterprise(armCluster, armDatabase) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } // TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource. @@ -479,7 +480,7 @@ func TestDatabaseFromRDSInstance(t *testing.T) { require.NoError(t, err) actual, err := NewDatabaseFromRDSInstance(instance) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } // TestDatabaseFromRDSInstance tests converting an RDS instance to a database resource. @@ -531,7 +532,7 @@ func TestDatabaseFromRDSInstanceNameOverride(t *testing.T) { require.NoError(t, err) actual, err := NewDatabaseFromRDSInstance(instance) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } // TestDatabaseFromRDSCluster tests converting an RDS cluster to a database resource. @@ -587,7 +588,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) { require.NoError(t, err) actual, err := NewDatabaseFromRDSCluster(cluster) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("reader", func(t *testing.T) { @@ -611,7 +612,7 @@ func TestDatabaseFromRDSCluster(t *testing.T) { require.NoError(t, err) actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("custom endpoints", func(t *testing.T) { @@ -723,7 +724,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { require.NoError(t, err) actual, err := NewDatabaseFromRDSCluster(cluster) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("reader", func(t *testing.T) { @@ -748,7 +749,7 @@ func TestDatabaseFromRDSClusterNameOverride(t *testing.T) { require.NoError(t, err) actual, err := NewDatabaseFromRDSClusterReaderEndpoint(cluster) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("custom endpoints", func(t *testing.T) { @@ -858,7 +859,7 @@ func TestDatabaseFromRDSProxy(t *testing.T) { actual, err := NewDatabaseFromRDSProxy(dbProxy, port, tags) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("custom endpoint", func(t *testing.T) { @@ -894,7 +895,7 @@ func TestDatabaseFromRDSProxy(t *testing.T) { actual, err := NewDatabaseFromRDSProxyCustomEndpoint(dbProxy, dbProxyEndpoint, port, tags) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) } @@ -1080,7 +1081,7 @@ func TestDatabaseFromRedshiftCluster(t *testing.T) { actual, err := NewDatabaseFromRedshiftCluster(cluster) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("success with name override", func(t *testing.T) { @@ -1133,7 +1134,7 @@ func TestDatabaseFromRedshiftCluster(t *testing.T) { actual, err := NewDatabaseFromRedshiftCluster(cluster) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) }) t.Run("missing endpoint", func(t *testing.T) { @@ -1212,7 +1213,7 @@ func TestDatabaseFromElastiCacheConfigurationEndpoint(t *testing.T) { actual, err := NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromElastiCacheConfigurationEndpointNameOverride(t *testing.T) { @@ -1286,7 +1287,7 @@ func TestDatabaseFromElastiCacheConfigurationEndpointNameOverride(t *testing.T) actual, err := NewDatabaseFromElastiCacheConfigurationEndpoint(cluster, extraLabels) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromElastiCacheNodeGroups(t *testing.T) { @@ -1498,7 +1499,7 @@ func TestDatabaseFromMemoryDBCluster(t *testing.T) { actual, err := NewDatabaseFromMemoryDBCluster(cluster, extraLabels) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromRedshiftServerlessWorkgroup(t *testing.T) { @@ -1532,7 +1533,7 @@ func TestDatabaseFromRedshiftServerlessWorkgroup(t *testing.T) { actual, err := NewDatabaseFromRedshiftServerlessWorkgroup(workgroup, tags) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromRedshiftServerlessVPCEndpoint(t *testing.T) { @@ -1572,7 +1573,7 @@ func TestDatabaseFromRedshiftServerlessVPCEndpoint(t *testing.T) { actual, err := NewDatabaseFromRedshiftServerlessVPCEndpoint(endpoint, workgroup, tags) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestDatabaseFromMemoryDBClusterNameOverride(t *testing.T) { @@ -1621,7 +1622,7 @@ func TestDatabaseFromMemoryDBClusterNameOverride(t *testing.T) { actual, err := NewDatabaseFromMemoryDBCluster(cluster, extraLabels) require.NoError(t, err) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestExtraElastiCacheLabels(t *testing.T) { @@ -1750,7 +1751,7 @@ func TestExtraMemoryDBLabels(t *testing.T) { } actual := ExtraMemoryDBLabels(cluster, resourceTags, allSubnetGroups) - require.Equal(t, expected, actual) + require.Empty(t, cmp.Diff(expected, actual)) } func TestGetLabelEngineVersion(t *testing.T) { diff --git a/lib/services/role.go b/lib/services/role.go index e5f41b4d347d0..d842819bed286 100644 --- a/lib/services/role.go +++ b/lib/services/role.go @@ -2174,7 +2174,11 @@ func makeAlternativeNamesForAWSRole(db types.Database, user string) []string { } // If input database user is the short role name, try the full ARN. - return []string{awsutils.BuildRoleARN(user, metadata.Region, metadata.AccountID)} + roleARN, err := awsutils.BuildRoleARN(user, metadata.Region, metadata.AccountID) + if err != nil { + return nil + } + return []string{roleARN} } // DatabaseNameMatcher matches a role against database name. diff --git a/lib/srv/app/cloud.go b/lib/srv/app/cloud.go index 2b95d1f055711..210cc40b96d38 100644 --- a/lib/srv/app/cloud.go +++ b/lib/srv/app/cloud.go @@ -26,7 +26,6 @@ import ( "time" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/aws/credentials/ec2rolecreds" "github.com/aws/aws-sdk-go/aws/credentials/ssocreds" "github.com/aws/aws-sdk-go/aws/credentials/stscreds" @@ -37,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/constants" "github.com/gravitational/teleport/lib/tlsca" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // Cloud provides cloud provider access related methods such as generating @@ -63,7 +63,7 @@ func (r *AWSSigninRequest) CheckAndSetDefaults() error { if r.Identity == nil { return trace.BadParameter("missing Identity") } - _, err := arn.Parse(r.Identity.RouteToApp.AWSRoleARN) + _, err := awsutils.ParseRoleARN(r.Identity.RouteToApp.AWSRoleARN) if err != nil { return trace.Wrap(err) } diff --git a/lib/srv/app/server_test.go b/lib/srv/app/server_test.go index 6526f557ad9f3..24f19b170296a 100644 --- a/lib/srv/app/server_test.go +++ b/lib/srv/app/server_test.go @@ -188,7 +188,7 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite { Spec: types.RoleSpecV6{ Allow: types.RoleConditions{ AppLabels: roleAppLabels, - AWSRoleARNs: []string{"readonly"}, + AWSRoleARNs: []string{"arn:aws:iam::123456789012:role/readonly"}, }, }, } @@ -282,7 +282,7 @@ func SetUpSuiteWithConfig(t *testing.T, config suiteConfig) *Suite { s.clientCertificate = s.generateCertificate(t, s.user, "foo.example.com", "") // Generate certificate for AWS console application. - s.awsConsoleCertificate = s.generateCertificate(t, s.user, "aws.example.com", "readonly") + s.awsConsoleCertificate = s.generateCertificate(t, s.user, "aws.example.com", "arn:aws:iam::123456789012:role/readonly") lockWatcher, err := services.NewLockWatcher(s.closeContext, services.LockWatcherConfig{ ResourceWatcherConfig: services.ResourceWatcherConfig{ diff --git a/lib/srv/db/cassandra/handshake.go b/lib/srv/db/cassandra/handshake.go index e19554de0227b..73c66c1dae543 100644 --- a/lib/srv/db/cassandra/handshake.go +++ b/lib/srv/db/cassandra/handshake.go @@ -199,7 +199,11 @@ func (a *authAWSSigV4Auth) getSigV4Authenticator(username string) (gocql.Authent } region := a.ses.Database.GetAWS().Region accountID := a.ses.Database.GetAWS().AccountID - cred, err := stscreds.NewCredentials(session, awsutils.BuildRoleARN(username, region, accountID)).Get() + roleARN, err := awsutils.BuildRoleARN(username, region, accountID) + if err != nil { + return nil, trace.Wrap(err) + } + cred, err := stscreds.NewCredentials(session, roleARN).Get() if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/cloud/meta.go b/lib/srv/db/cloud/meta.go index 4fba179751f10..599fef830e979 100644 --- a/lib/srv/db/cloud/meta.go +++ b/lib/srv/db/cloud/meta.go @@ -21,7 +21,6 @@ import ( "strings" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" "github.com/aws/aws-sdk-go/service/elasticache" "github.com/aws/aws-sdk-go/service/elasticache/elasticacheiface" "github.com/aws/aws-sdk-go/service/memorydb" @@ -155,25 +154,20 @@ func (m *Metadata) fetchRDSProxyMetadata(ctx context.Context, database types.Dat // fetchRedshiftMetadata fetches metadata for the provided Redshift database. func (m *Metadata) fetchRedshiftMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { - redshift, err := m.cfg.Clients.GetAWSRedshiftClient(database.GetAWS().Region) + meta := database.GetAWS() + redshift, err := m.cfg.Clients.GetAWSRedshiftClient(meta.Region) if err != nil { return nil, trace.Wrap(err) } - cluster, err := describeRedshiftCluster(ctx, redshift, database.GetAWS().Redshift.ClusterID) + cluster, err := describeRedshiftCluster(ctx, redshift, meta.Redshift.ClusterID) if err != nil { return nil, trace.Wrap(err) } - parsedARN, err := arn.Parse(aws.StringValue(cluster.ClusterNamespaceArn)) + fetchedMetadata, err := services.MetadataFromRedshiftCluster(cluster) if err != nil { return nil, trace.Wrap(err) } - return &types.AWS{ - Region: parsedARN.Region, - AccountID: parsedARN.AccountID, - Redshift: types.Redshift{ - ClusterID: aws.StringValue(cluster.ClusterIdentifier), - }, - }, nil + return fetchedMetadata, nil } // fetchRedshiftServerlessMetadata fetches metadata for the provided Redshift @@ -193,33 +187,35 @@ func (m *Metadata) fetchRedshiftServerlessMetadata(ctx context.Context, database // fetchElastiCacheMetadata fetches metadata for the provided ElastiCache database. func (m *Metadata) fetchElastiCacheMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { - elastiCacheClient, err := m.cfg.Clients.GetAWSElastiCacheClient(database.GetAWS().Region) + meta := database.GetAWS() + elastiCacheClient, err := m.cfg.Clients.GetAWSElastiCacheClient(meta.Region) if err != nil { return nil, trace.Wrap(err) } - cluster, err := describeElastiCacheCluster(ctx, elastiCacheClient, database.GetAWS().ElastiCache.ReplicationGroupID) + cluster, err := describeElastiCacheCluster(ctx, elastiCacheClient, meta.ElastiCache.ReplicationGroupID) if err != nil { return nil, trace.Wrap(err) } // Endpoint type does not change. - endpointType := database.GetAWS().ElastiCache.EndpointType + endpointType := meta.ElastiCache.EndpointType return services.MetadataFromElastiCacheCluster(cluster, endpointType) } // fetchMemoryDBMetadata fetches metadata for the provided MemoryDB database. func (m *Metadata) fetchMemoryDBMetadata(ctx context.Context, database types.Database) (*types.AWS, error) { - memoryDBClient, err := m.cfg.Clients.GetAWSMemoryDBClient(database.GetAWS().Region) + meta := database.GetAWS() + memoryDBClient, err := m.cfg.Clients.GetAWSMemoryDBClient(meta.Region) if err != nil { return nil, trace.Wrap(err) } - cluster, err := describeMemoryDBCluster(ctx, memoryDBClient, database.GetAWS().MemoryDB.ClusterName) + cluster, err := describeMemoryDBCluster(ctx, memoryDBClient, meta.MemoryDB.ClusterName) if err != nil { return nil, trace.Wrap(err) } // Endpoint type does not change. - endpointType := database.GetAWS().MemoryDB.EndpointType + endpointType := meta.MemoryDB.EndpointType return services.MetadataFromMemoryDBCluster(cluster, endpointType) } diff --git a/lib/srv/db/cloud/users/elasticache.go b/lib/srv/db/cloud/users/elasticache.go index 100e7dc6ac69c..bc8b0a5f7b787 100644 --- a/lib/srv/db/cloud/users/elasticache.go +++ b/lib/srv/db/cloud/users/elasticache.go @@ -66,11 +66,12 @@ func (f *elastiCacheFetcher) GetType() string { // FetchDatabaseUsers fetches users for provided database. Implements Fetcher. func (f *elastiCacheFetcher) FetchDatabaseUsers(ctx context.Context, database types.Database) ([]User, error) { - if len(database.GetAWS().ElastiCache.UserGroupIDs) == 0 { + meta := database.GetAWS() + if len(meta.ElastiCache.UserGroupIDs) == 0 { return nil, nil } - client, err := f.cfg.Clients.GetAWSElastiCacheClient(database.GetAWS().Region) + client, err := f.cfg.Clients.GetAWSElastiCacheClient(meta.Region) if err != nil { return nil, trace.Wrap(err) } @@ -81,8 +82,8 @@ func (f *elastiCacheFetcher) FetchDatabaseUsers(ctx context.Context, database ty } users := []User{} - for _, userGroupID := range database.GetAWS().ElastiCache.UserGroupIDs { - managedUsers, err := f.getManagedUsersForGroup(ctx, database.GetAWS().Region, userGroupID, client) + for _, userGroupID := range meta.ElastiCache.UserGroupIDs { + managedUsers, err := f.getManagedUsersForGroup(ctx, meta.Region, userGroupID, client) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/cloud/users/helpers.go b/lib/srv/db/cloud/users/helpers.go index 7d074bdc66e1b..908a33ce0a7cf 100644 --- a/lib/srv/db/cloud/users/helpers.go +++ b/lib/srv/db/cloud/users/helpers.go @@ -155,7 +155,8 @@ func genRandomPassword(length int) (string, error) { func newSecretStore(database types.Database, clients cloud.Clients) (secrets.Secrets, error) { secretStoreConfig := database.GetSecretStore() - client, err := clients.GetAWSSecretsManagerClient(database.GetAWS().Region) + meta := database.GetAWS() + client, err := clients.GetAWSSecretsManagerClient(meta.Region) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/cloud/users/memorydb.go b/lib/srv/db/cloud/users/memorydb.go index 22552064c5aab..9a3ca86e621a8 100644 --- a/lib/srv/db/cloud/users/memorydb.go +++ b/lib/srv/db/cloud/users/memorydb.go @@ -67,11 +67,12 @@ func (f *memoryDBFetcher) GetType() string { // FetchDatabaseUsers fetches users for provided database. Implements Fetcher. func (f *memoryDBFetcher) FetchDatabaseUsers(ctx context.Context, database types.Database) ([]User, error) { - if database.GetAWS().MemoryDB.ACLName == "" { + meta := database.GetAWS() + if meta.MemoryDB.ACLName == "" { return nil, nil } - client, err := f.cfg.Clients.GetAWSMemoryDBClient(database.GetAWS().Region) + client, err := f.cfg.Clients.GetAWSMemoryDBClient(meta.Region) if err != nil { return nil, trace.Wrap(err) } @@ -82,7 +83,7 @@ func (f *memoryDBFetcher) FetchDatabaseUsers(ctx context.Context, database types } users := []User{} - mdbUsers, err := f.getManagedUsersForACL(ctx, database.GetAWS().Region, database.GetAWS().MemoryDB.ACLName, client) + mdbUsers, err := f.getManagedUsersForACL(ctx, meta.Region, meta.MemoryDB.ACLName, client) if err != nil { return nil, trace.Wrap(err) } diff --git a/lib/srv/db/common/auth.go b/lib/srv/db/common/auth.go index 695338dbd3bbe..ed993435aa2cc 100644 --- a/lib/srv/db/common/auth.go +++ b/lib/srv/db/common/auth.go @@ -30,8 +30,6 @@ import ( "github.com/Azure/azure-sdk-for-go/sdk/azcore/arm" "github.com/Azure/azure-sdk-for-go/sdk/azcore/policy" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/arn" - "github.com/aws/aws-sdk-go/service/iam" "github.com/aws/aws-sdk-go/service/rds/rdsutils" "github.com/aws/aws-sdk-go/service/redshift" "github.com/aws/aws-sdk-go/service/redshiftserverless" @@ -42,7 +40,6 @@ import ( "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" - awsutils "github.com/gravitational/teleport/api/utils/aws" azureutils "github.com/gravitational/teleport/api/utils/azure" "github.com/gravitational/teleport/api/utils/retryutils" libauth "github.com/gravitational/teleport/lib/auth" @@ -54,6 +51,7 @@ import ( dbiam "github.com/gravitational/teleport/lib/srv/db/common/iam" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // azureVirtualMachineCacheTTL is the default TTL for Azure virtual machine @@ -160,14 +158,15 @@ func NewAuth(config AuthConfig) (Auth, error) { // GetRDSAuthToken returns authorization token that will be used as a password // when connecting to RDS and Aurora databases. func (a *dbAuth) GetRDSAuthToken(sessionCtx *Session) (string, error) { - awsSession, err := a.cfg.Clients.GetAWSSession(sessionCtx.Database.GetAWS().Region) + meta := sessionCtx.Database.GetAWS() + awsSession, err := a.cfg.Clients.GetAWSSession(meta.Region) if err != nil { return "", trace.Wrap(err) } a.cfg.Log.Debugf("Generating RDS auth token for %s.", sessionCtx) token, err := rdsutils.BuildAuthToken( sessionCtx.Database.GetURI(), - sessionCtx.Database.GetAWS().Region, + meta.Region, sessionCtx.DatabaseUser, awsSession.Config.Credentials) if err != nil { @@ -191,13 +190,14 @@ permissions (note that IAM changes may take a few minutes to propagate): // GetRedshiftAuthToken returns authorization token that will be used as a // password when connecting to Redshift databases. func (a *dbAuth) GetRedshiftAuthToken(sessionCtx *Session) (string, string, error) { - awsSession, err := a.cfg.Clients.GetAWSSession(sessionCtx.Database.GetAWS().Region) + meta := sessionCtx.Database.GetAWS() + awsSession, err := a.cfg.Clients.GetAWSSession(meta.Region) if err != nil { return "", "", trace.Wrap(err) } a.cfg.Log.Debugf("Generating Redshift auth token for %s.", sessionCtx) resp, err := redshift.New(awsSession).GetClusterCredentials(&redshift.GetClusterCredentialsInput{ - ClusterIdentifier: aws.String(sessionCtx.Database.GetAWS().Redshift.ClusterID), + ClusterIdentifier: aws.String(meta.Redshift.ClusterID), DbUser: aws.String(sessionCtx.DatabaseUser), DbName: aws.String(sessionCtx.DatabaseName), // TODO(r0mant): Do not auto-create database account if DbUser doesn't @@ -232,12 +232,12 @@ func (a *dbAuth) GetRedshiftServerlessAuthToken(ctx context.Context, sessionCtx // example, an IAM role "arn:aws:iam::1234567890:role/my-role-name" will be // mapped to a Postgres user "IAMR:my-role-name" inside the database. So we // first need to assume this IAM role before getting auth token. - awsMetadata := sessionCtx.Database.GetAWS() - roleARN, err := redshiftServerlessUsernameToRoleARN(awsMetadata, sessionCtx.DatabaseUser) + meta := sessionCtx.Database.GetAWS() + roleARN, err := redshiftServerlessUsernameToRoleARN(meta, sessionCtx.DatabaseUser) if err != nil { return "", "", trace.Wrap(err) } - client, err := a.cfg.Clients.GetAWSRedshiftServerlessClientForRole(ctx, awsMetadata.Region, roleARN) + client, err := a.cfg.Clients.GetAWSRedshiftServerlessClientForRole(ctx, meta.Region, roleARN) if err != nil { return "", "", trace.AccessDenied(`Could not generate Redshift Serverless auth token: @@ -250,7 +250,7 @@ Make sure that IAM role %q has a trust relationship with Teleport database agent // Now make the API call to generate the temporary credentials. a.cfg.Log.Debugf("Generating Redshift Serverless auth token for %s.", sessionCtx) resp, err := client.GetCredentialsWithContext(ctx, &redshiftserverless.GetCredentialsInput{ - WorkgroupName: aws.String(awsMetadata.RedshiftServerless.WorkgroupName), + WorkgroupName: aws.String(meta.RedshiftServerless.WorkgroupName), DbName: aws.String(sessionCtx.DatabaseName), }) if err != nil { @@ -816,32 +816,11 @@ func matchAzureResourceName(resourceID, name string) bool { // redshiftServerlessUsernameToRoleARN converts a database username to AWS role // ARN for a Redshift Serverless database. func redshiftServerlessUsernameToRoleARN(aws types.AWS, username string) (string, error) { - switch { // These are in-database usernames created when logged in as IAM // users/roles. We will enforce Teleport users to provide IAM roles // instead. - case strings.HasPrefix(username, "IAM:") || strings.HasPrefix(username, "IAMR:"): + if strings.HasPrefix(username, "IAM:") || strings.HasPrefix(username, "IAMR:") { return "", trace.BadParameter("expecting name or ARN of an AWS IAM role but got %v", username) - - case arn.IsARN(username): - if parsedARN, err := arn.Parse(username); err != nil { - return "", trace.Wrap(err) - } else if parsedARN.Service != iam.ServiceName || !strings.HasPrefix(parsedARN.Resource, "role/") { - return "", trace.BadParameter("expecting name or ARN of an AWS IAM role but got %v", username) - } - return username, nil - - default: - resource := username - if !strings.Contains(resource, "/") { - resource = fmt.Sprintf("role/%s", username) - } - - return arn.ARN{ - Partition: awsutils.GetPartitionFromRegion(aws.Region), - Service: iam.ServiceName, - AccountID: aws.AccountID, - Resource: resource, - }.String(), nil } + return awsutils.BuildRoleARN(username, aws.Region, aws.AccountID) } diff --git a/lib/srv/db/dynamodb/engine.go b/lib/srv/db/dynamodb/engine.go index afb577294d130..ccc060bdeabbf 100644 --- a/lib/srv/db/dynamodb/engine.go +++ b/lib/srv/db/dynamodb/engine.go @@ -171,15 +171,15 @@ func (e *Engine) process(ctx context.Context, req *http.Request) (err error) { defer req.Body.Close() } - var responseStatusCode uint32 re, err := e.resolveEndpoint(req) if err != nil { // special error case where we couldn't resolve the endpoint, just emit using the configured URI. - e.emitAuditEvent(req, e.sessionCtx.Database.GetURI(), responseStatusCode, err) + e.emitAuditEvent(req, e.sessionCtx.Database.GetURI(), 0, err) return trace.Wrap(err) } // emit an audit event regardless of failure, but using the resolved endpoint. + var responseStatusCode uint32 defer func() { e.emitAuditEvent(req, re.URL, responseStatusCode, err) }() @@ -199,7 +199,11 @@ func (e *Engine) process(ctx context.Context, req *http.Request) (err error) { return trace.Wrap(err) } - roleArn := libaws.BuildRoleARN(e.sessionCtx.DatabaseUser, re.SigningRegion, e.sessionCtx.Database.GetAWS().AccountID) + roleArn, err := libaws.BuildRoleARN(e.sessionCtx.DatabaseUser, + re.SigningRegion, e.sessionCtx.Database.GetAWS().AccountID) + if err != nil { + return trace.Wrap(err) + } signedReq, err := e.signingSvc.SignRequest(e.Context, outReq, &libaws.SigningCtx{ SigningName: re.SigningName, diff --git a/lib/srv/db/dynamodb_test.go b/lib/srv/db/dynamodb_test.go index 4e175202474e8..369817f511895 100644 --- a/lib/srv/db/dynamodb_test.go +++ b/lib/srv/db/dynamodb_test.go @@ -223,7 +223,7 @@ func withDynamoDB(name string, opts ...dynamodb.TestServerOption) withDatabaseOp DynamicLabels: dynamicLabels, AWS: types.AWS{ Region: "us-west-1", - AccountID: "12345", + AccountID: "123456789012", }, TLS: types.DatabaseTLS{ // Set CA, otherwise the engine will attempt to download and use the AWS CA. diff --git a/lib/utils/aws/aws.go b/lib/utils/aws/aws.go index 89abed02babdd..dfdba5feea7b8 100644 --- a/lib/utils/aws/aws.go +++ b/lib/utils/aws/aws.go @@ -31,6 +31,7 @@ import ( v4 "github.com/aws/aws-sdk-go/aws/signer/v4" "github.com/aws/aws-sdk-go/service/iam" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" apievents "github.com/gravitational/teleport/api/types/events" apiawsutils "github.com/gravitational/teleport/api/utils/aws" @@ -241,11 +242,15 @@ func filterHeaders(r *http.Request, headers []string) []string { // FilterAWSRoles returns role ARNs from the provided list that belong to the // specified AWS account ID. // -// If AWS account ID is empty, all roles are returned. +// If AWS account ID is empty, all valid AWS IAM roles are returned. func FilterAWSRoles(arns []string, accountID string) (result Roles) { for _, roleARN := range arns { - parsed, err := arn.Parse(roleARN) - if err != nil || (accountID != "" && parsed.AccountID != accountID) { + parsed, err := ParseRoleARN(roleARN) + if err != nil { + logrus.Warnf("skipping invalid AWS role ARN: %v", err) + continue + } + if accountID != "" && parsed.AccountID != accountID { continue } @@ -256,13 +261,9 @@ func FilterAWSRoles(arns []string, accountID string) (result Roles) { // arn:aws:iam::1234567890:role/EC2FullAccess (display: EC2FullAccess) // arn:aws:iam::1234567890:role/path/to/customrole (display: customrole) parts := strings.Split(parsed.Resource, "/") - numParts := len(parts) - if numParts < 2 || parts[0] != "role" { - continue - } result = append(result, Role{ Name: strings.Join(parts[1:], "/"), - Display: parts[numParts-1], + Display: parts[len(parts)-1], ARN: roleARN, }) } @@ -348,37 +349,89 @@ func isJSON(contentType string) bool { } // BuildRoleARN constructs a string AWS ARN from a username, region, and account ID. -func BuildRoleARN(username, region, accountID string) string { +// If username is an AWS ARN, this function checks that the ARN is an AWS IAM Role ARN +// in the correct partition and account. +func BuildRoleARN(username, region, accountID string) (string, error) { + partition := apiawsutils.GetPartitionFromRegion(region) if arn.IsARN(username) { - return username + // sanity check the given username role ARN. + parsed, err := ParseRoleARN(username) + if err != nil { + return "", trace.Wrap(err) + } + // don't check for empty accountID - callers do not always pass an account ID, + // and it's only absolutely required if we need to build the role ARN below. + if err := CheckARNPartitionAndAccount(parsed, partition, accountID); err != nil { + return "", trace.Wrap(err) + } + return username, nil } resource := username - if !strings.Contains(resource, "/") { + if !strings.HasPrefix(resource, "role/") { resource = fmt.Sprintf("role/%s", username) } - return arn.ARN{ - Partition: apiawsutils.GetPartitionFromRegion(region), + roleARN := &arn.ARN{ + Partition: partition, Service: iam.ServiceName, AccountID: accountID, Resource: resource, - }.String() + } + if err := checkRoleARN(roleARN); err != nil { + return "", trace.Wrap(err) + } + return roleARN.String(), nil } // ValidateRoleARNAndExtractRoleName validates the role ARN and extracts the // short role name from it. func ValidateRoleARNAndExtractRoleName(roleARN, wantPartition, wantAccountID string) (string, error) { - role, err := arn.Parse(roleARN) + role, err := ParseRoleARN(roleARN) if err != nil { return "", trace.Wrap(err) } - if !strings.HasPrefix(role.Resource, "role/") || role.Service != iam.ServiceName { - return "", trace.BadParameter("%q is not an IAM role", roleARN) + if err := CheckARNPartitionAndAccount(role, wantPartition, wantAccountID); err != nil { + return "", trace.Wrap(err) } - if role.Partition != wantPartition { - return "", trace.BadParameter("expecting AWS partition %q but got %q", wantPartition, role.Partition) + return strings.TrimPrefix(role.Resource, "role/"), nil +} + +// ParseRoleARN parses an AWS ARN and checks that the ARN is +// for an IAM Role resource. +func ParseRoleARN(roleARN string) (*arn.ARN, error) { + role, err := arn.Parse(roleARN) + if err != nil { + return nil, trace.BadParameter("invalid AWS ARN: %v", err) } - if role.AccountID != wantAccountID { - return "", trace.BadParameter("expecting AWS account ID %q but got %q", wantAccountID, role.AccountID) + if err := checkRoleARN(&role); err != nil { + return nil, trace.Wrap(err) } - return strings.TrimPrefix(role.Resource, "role/"), nil + return &role, nil +} + +// checkRoleARN returns whether a parsed ARN is for an IAM Role resource. +// Example role ARN: arn:aws:iam::123456789012:role/some-role-name +func checkRoleARN(parsed *arn.ARN) error { + parts := strings.Split(parsed.Resource, "/") + if parts[0] != "role" || parsed.Service != iam.ServiceName { + return trace.BadParameter("%q is not an AWS IAM role ARN", parsed) + } + if len(parts) < 2 { + return trace.BadParameter("%q is missing AWS IAM role name", parsed) + } + if err := apiawsutils.IsValidAccountID(parsed.AccountID); err != nil { + return trace.BadParameter("%q invalid account ID: %v", parsed, err) + } + return nil +} + +// CheckARNPartitionAndAccount checks an AWS ARN against an expected AWS partition and account ID. +// An empty expected AWS partition or account ID is not checked. +func CheckARNPartitionAndAccount(ARN *arn.ARN, wantPartition, wantAccountID string) error { + if ARN.Partition != wantPartition && wantPartition != "" { + return trace.BadParameter("expected AWS partition %q but got %q", wantPartition, ARN.Partition) + } + if ARN.AccountID != wantAccountID && wantAccountID != "" { + return trace.BadParameter("expected AWS account ID %q but got %q", wantAccountID, ARN.AccountID) + } + return nil } diff --git a/lib/utils/aws/aws_test.go b/lib/utils/aws/aws_test.go index 20c476573b98e..25407a6575283 100644 --- a/lib/utils/aws/aws_test.go +++ b/lib/utils/aws/aws_test.go @@ -255,3 +255,140 @@ func TestValidateRoleARNAndExtractRoleName(t *testing.T) { }) } } + +func TestParseRoleARN(t *testing.T) { + tests := map[string]struct { + arn string + wantErrContains string + }{ + "valid role arn": { + arn: "arn:aws:iam::123456789012:role/test-role", + }, + "valid sso role arn": { + arn: "arn:aws:iam::123456789012:role/aws-reserved/sso.amazonaws.com/us-west-2/AWSReservedSSO_AWSPowerUserAccess_xxxxxxxxx", + }, + "valid service role arn": { + arn: "arn:aws:iam::123456789012:role/aws-service-role/redshift.amazonaws.com/AWSServiceRoleForRedshift", + }, + "arn fails to parse": { + arn: "foobar", + wantErrContains: "invalid AWS ARN", + }, + "sts arn is not iam": { + arn: "arn:aws:sts::123456789012:federated-user/Alice", + wantErrContains: "not an AWS IAM role", + }, + "iam arn is not a role": { + arn: "arn:aws:iam::123456789012:user/test-user", + wantErrContains: "not an AWS IAM role", + }, + "iam role arn is missing role name": { + arn: "arn:aws:iam::123456789012:role", + wantErrContains: "missing AWS IAM role name", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + got, err := ParseRoleARN(tt.arn) + if tt.wantErrContains != "" { + require.Error(t, err, err.Error()) + require.ErrorContains(t, err, tt.wantErrContains) + return + } + require.NoError(t, err) + require.NotNil(t, got) + }) + } +} + +func TestBuildRoleARN(t *testing.T) { + tests := map[string]struct { + user string + region string + accountID string + wantErrContains string + wantARN string + }{ + "valid role arn in correct partition and account": { + user: "arn:aws:iam::123456789012:role/test-role", + region: "us-west-1", + accountID: "123456789012", + wantARN: "arn:aws:iam::123456789012:role/test-role", + }, + "valid role arn in correct account and default partition": { + user: "arn:aws:iam::123456789012:role/test-role", + region: "", + accountID: "123456789012", + wantARN: "arn:aws:iam::123456789012:role/test-role", + }, + "valid role arn in default partition and account": { + user: "arn:aws:iam::123456789012:role/test-role", + region: "", + accountID: "", + wantARN: "arn:aws:iam::123456789012:role/test-role", + }, + "role name with prefix in default partition and account": { + user: "role/test-role", + region: "", + accountID: "123456789012", + wantARN: "arn:aws:iam::123456789012:role/test-role", + }, + "role name in default partition and account": { + user: "test-role", + region: "", + accountID: "123456789012", + wantARN: "arn:aws:iam::123456789012:role/test-role", + }, + "role name in china partition and account": { + user: "test-role", + region: "cn-north-1", + accountID: "123456789012", + wantARN: "arn:aws-cn:iam::123456789012:role/test-role", + }, + "valid ARN is not an IAM role ARN": { + user: "arn:aws:iam::123456789012:user/test-user", + region: "", + accountID: "", + wantErrContains: "not an AWS IAM role", + }, + "valid role arn in different partition": { + user: "arn:aws-cn:iam::123456789012:role/test-role", + region: "us-west-1", + accountID: "", + wantErrContains: `expected AWS partition "aws" but got "aws-cn"`, + }, + "valid role arn in different account": { + user: "arn:aws:iam::123456789012:role/test-role", + region: "us-west-1", + accountID: "111222333444", + wantErrContains: `expected AWS account ID "111222333444" but got "123456789012"`, + }, + "role name with invalid account characters": { + user: "test-role", + region: "", + accountID: "12345678901f", + wantErrContains: "must be 12-digit", + }, + "role name with invalid account id length": { + user: "test-role", + region: "", + accountID: "1234567890123", + wantErrContains: "must be 12-digit", + }, + } + + for name, tt := range tests { + t.Run(name, func(t *testing.T) { + got, err := BuildRoleARN(tt.user, tt.region, tt.accountID) + if tt.wantErrContains != "" { + require.Error(t, err) + require.ErrorContains(t, err, tt.wantErrContains) + return + } + require.NoError(t, err) + require.NotEmpty(t, got) + require.Equal(t, tt.wantARN, got) + }) + } +} diff --git a/lib/utils/aws/signing.go b/lib/utils/aws/signing.go index f7b856938bcde..c51a79bc7e054 100644 --- a/lib/utils/aws/signing.go +++ b/lib/utils/aws/signing.go @@ -111,9 +111,12 @@ func (sc *SigningCtx) Check(clock clockwork.Clock) error { return trace.BadParameter("missing AWS Role ARN") case sc.Expiry.Before(clock.Now()): return trace.BadParameter("AWS SigV4 expiry has already expired") - default: - return nil } + _, err := ParseRoleARN(sc.AWSRoleArn) + if err != nil { + return trace.Wrap(err) + } + return nil } // SignRequest creates a new HTTP request and rewrites the header from the original request and returns a new diff --git a/tool/tsh/app_aws_test.go b/tool/tsh/app_aws_test.go index d6133eff58352..49b5ed7085209 100644 --- a/tool/tsh/app_aws_test.go +++ b/tool/tsh/app_aws_test.go @@ -139,7 +139,7 @@ func makeUserWithAWSRole(t *testing.T) (types.User, types.Role) { types.Wildcard: apiutils.Strings{types.Wildcard}, }, AWSRoleARNs: []string{ - "arn:aws:iam::123456890:role/some-aws-role", + "arn:aws:iam::123456789012:role/some-aws-role", }, }, }) diff --git a/tool/tsh/db_test.go b/tool/tsh/db_test.go index d68220e284717..22765efdb3afe 100644 --- a/tool/tsh/db_test.go +++ b/tool/tsh/db_test.go @@ -93,7 +93,7 @@ func TestDatabaseLogin(t *testing.T) { Protocol: defaults.ProtocolDynamoDB, URI: "", // uri can be blank for DynamoDB, it will be derived from the region and requests. AWS: servicecfg.DatabaseAWS{ - AccountID: "12345", + AccountID: "123456789012", ExternalID: "123123123", Region: "us-west-1", }, From 0a629bc7c3064f9909be48d4bdd0f4619b75348b Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Wed, 15 Mar 2023 18:06:12 -0700 Subject: [PATCH 2/4] Add configuration for AWS assume_role_arn --- api/types/database.go | 5 ++ api/types/database_test.go | 32 ++++++++- lib/config/configuration.go | 30 +++++--- lib/config/configuration_test.go | 100 ++++++++++++++++++-------- lib/config/database.go | 7 +- lib/config/database_test.go | 108 +++++++++++++++++++++++------ lib/config/fileconf.go | 18 +++++ lib/config/fileconf_test.go | 64 ++++++++++++++++- lib/config/testdata_test.go | 7 ++ lib/service/servicecfg/database.go | 21 +++++- lib/services/database.go | 32 +++++++-- lib/services/database_test.go | 36 ++++++++++ lib/services/matchers.go | 26 +++++++ tool/teleport/common/teleport.go | 2 + 14 files changed, 412 insertions(+), 76 deletions(-) diff --git a/api/types/database.go b/api/types/database.go index 05dcaa97e6a09..bac8370954c74 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -679,6 +679,11 @@ func (d *DatabaseV3) CheckAndSetDefaults() error { } } + if d.Spec.AWS.AssumeRoleARN == "" && d.Spec.AWS.ExternalID != "" { + return trace.BadParameter("AWS database %q has external_id %q, but assume_role_arn is missing", + d.GetName(), d.Spec.AWS.ExternalID) + } + // Validate Cloud SQL specific configuration. switch { case d.Spec.GCP.ProjectID != "" && d.Spec.GCP.InstanceID == "": diff --git a/api/types/database_test.go b/api/types/database_test.go index 11f6ee4c84ef9..9ed3debd26904 100644 --- a/api/types/database_test.go +++ b/api/types/database_test.go @@ -549,6 +549,8 @@ func TestDynamoDBConfig(t *testing.T) { uri string region string account string + roleARN string + externalID string wantSpec DatabaseSpecV3 wantErrMsg string }{ @@ -564,6 +566,22 @@ func TestDynamoDBConfig(t *testing.T) { }, }, }, + { + desc: "account and region and assume role is correct", + region: "us-west-1", + account: "123456789012", + roleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + externalID: "externalid123", + wantSpec: DatabaseSpecV3{ + URI: "aws://dynamodb.us-west-1.amazonaws.com", + AWS: AWS{ + Region: "us-west-1", + AccountID: "123456789012", + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalid123", + }, + }, + }, { desc: "account and AWS URI and empty region is correct", uri: "dynamodb.us-west-1.amazonaws.com", @@ -658,6 +676,14 @@ func TestDynamoDBConfig(t *testing.T) { account: "12345", wantErrMsg: "must be 12-digit", }, + { + desc: "configured external ID but not assume role is an error", + uri: "localhost:8080", + region: "us-west-1", + account: "123456789012", + externalID: "externalid123", + wantErrMsg: "assume_role_arn is missing", + }, } for _, tt := range tests { @@ -670,8 +696,10 @@ func TestDynamoDBConfig(t *testing.T) { Protocol: "dynamodb", URI: tt.uri, AWS: AWS{ - Region: tt.region, - AccountID: tt.account, + Region: tt.region, + AccountID: tt.account, + AssumeRoleARN: tt.roleARN, + ExternalID: tt.externalID, }, }) if tt.wantErrMsg != "" { diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 2607f5e4cf966..47ca121cbfa5b 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -144,6 +144,8 @@ type CommandLineFlags struct { DatabaseAWSRegion string // DatabaseAWSAccountID is an optional AWS account ID e.g. when using Keyspaces. DatabaseAWSAccountID string + // DatabaseAWSAssumeRoleARN is an optional AWS IAM role ARN to assume when accessing the database. + DatabaseAWSAssumeRoleARN string // DatabaseAWSExternalID is an optional AWS external ID used to enable assuming an AWS role across accounts. DatabaseAWSExternalID string // DatabaseAWSRedshiftClusterID is Redshift cluster identifier. @@ -1208,9 +1210,13 @@ func applyDiscoveryConfig(fc *FileConfig, cfg *servicecfg.Config) error { services.AWSMatcher{ Types: matcher.Types, Regions: matcher.Regions, - Tags: matcher.Tags, - Params: installParams, - SSM: &services.AWSSSM{DocumentName: matcher.SSM.DocumentName}, + AssumeRole: services.AssumeRole{ + RoleARN: matcher.AssumeRoleARN, + ExternalID: matcher.ExternalID, + }, + Tags: matcher.Tags, + Params: installParams, + SSM: &services.AWSSSM{DocumentName: matcher.SSM.DocumentName}, }) } @@ -1314,6 +1320,10 @@ func applyDatabasesConfig(fc *FileConfig, cfg *servicecfg.Config) error { Types: matcher.Types, Regions: matcher.Regions, Tags: matcher.Tags, + AssumeRole: services.AssumeRole{ + RoleARN: matcher.AssumeRoleARN, + ExternalID: matcher.ExternalID, + }, }) } for _, matcher := range fc.Databases.AzureMatchers { @@ -1363,9 +1373,10 @@ func applyDatabasesConfig(fc *FileConfig, cfg *servicecfg.Config) error { Mode: servicecfg.TLSMode(database.TLS.Mode), }, AWS: servicecfg.DatabaseAWS{ - AccountID: database.AWS.AccountID, - ExternalID: database.AWS.ExternalID, - Region: database.AWS.Region, + AccountID: database.AWS.AccountID, + AssumeRoleARN: database.AWS.AssumeRoleARN, + ExternalID: database.AWS.ExternalID, + Region: database.AWS.Region, Redshift: servicecfg.DatabaseAWSRedshift{ ClusterID: database.AWS.Redshift.ClusterID, }, @@ -1898,9 +1909,10 @@ func Configure(clf *CommandLineFlags, cfg *servicecfg.Config, legacyAppFlags boo CACert: caBytes, }, AWS: servicecfg.DatabaseAWS{ - Region: clf.DatabaseAWSRegion, - AccountID: clf.DatabaseAWSAccountID, - ExternalID: clf.DatabaseAWSExternalID, + Region: clf.DatabaseAWSRegion, + AccountID: clf.DatabaseAWSAccountID, + AssumeRoleARN: clf.DatabaseAWSAssumeRoleARN, + ExternalID: clf.DatabaseAWSExternalID, Redshift: servicecfg.DatabaseAWSRedshift{ ClusterID: clf.DatabaseAWSRedshiftClusterID, }, diff --git a/lib/config/configuration_test.go b/lib/config/configuration_test.go index 0f22409434184..c6b21305483a4 100644 --- a/lib/config/configuration_test.go +++ b/lib/config/configuration_test.go @@ -354,6 +354,8 @@ func TestConfigReading(t *testing.T) { Tags: map[string]apiutils.Strings{ "a": {"b"}, }, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", InstallParams: &InstallParams{ JoinParams: JoinParams{ TokenName: "aws-discovery-iam-token", @@ -471,6 +473,8 @@ func TestConfigReading(t *testing.T) { Tags: map[string]apiutils.Strings{ "a": {"b"}, }, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, { Types: []string{"rds"}, @@ -478,6 +482,7 @@ func TestConfigReading(t *testing.T) { Tags: map[string]apiutils.Strings{ "c": {"d"}, }, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", }, }, AzureMatchers: []AzureMatcher{ @@ -828,6 +833,17 @@ SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7 }, }, })) + require.Empty(t, cmp.Diff(cfg.Databases.AWSMatchers, + []services.AWSMatcher{ + { + Types: []string{"rds"}, + Regions: []string{"us-west-1"}, + AssumeRole: services.AssumeRole{ + RoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", + }, + }, + })) require.True(t, cfg.Kube.Enabled) require.Empty(t, cmp.Diff(cfg.Kube.ResourceMatchers, @@ -849,6 +865,8 @@ SREzU8onbBsjMg9QDiSf5oJLKvd/Ren+zGY7 require.True(t, cfg.Discovery.Enabled) require.Equal(t, cfg.Discovery.AWSMatchers[0].Regions, []string{"eu-central-1"}) require.Equal(t, cfg.Discovery.AWSMatchers[0].Types, []string{"ec2"}) + require.Equal(t, cfg.Discovery.AWSMatchers[0].AssumeRole.RoleARN, "arn:aws:iam::123456789012:role/DBDiscoverer") + require.Equal(t, cfg.Discovery.AWSMatchers[0].AssumeRole.ExternalID, "externalID123") require.Equal(t, cfg.Discovery.AWSMatchers[0].Params, services.InstallerParams{ InstallTeleport: true, JoinMethod: "iam", @@ -1421,9 +1439,11 @@ func makeConfigFixture() string { conf.Discovery.EnabledFlag = "true" conf.Discovery.AWSMatchers = []AWSMatcher{ { - Types: []string{"ec2"}, - Regions: []string{"us-west-1", "us-east-1"}, - Tags: map[string]apiutils.Strings{"a": {"b"}}, + Types: []string{"ec2"}, + Regions: []string{"us-west-1", "us-east-1"}, + Tags: map[string]apiutils.Strings{"a": {"b"}}, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, } @@ -1517,14 +1537,17 @@ func makeConfigFixture() string { } conf.Databases.AWSMatchers = []AWSMatcher{ { - Types: []string{"rds"}, - Regions: []string{"us-west-1", "us-east-1"}, - Tags: map[string]apiutils.Strings{"a": {"b"}}, + Types: []string{"rds"}, + Regions: []string{"us-west-1", "us-east-1"}, + Tags: map[string]apiutils.Strings{"a": {"b"}}, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, { - Types: []string{"rds"}, - Regions: []string{"us-central-1"}, - Tags: map[string]apiutils.Strings{"c": {"d"}}, + Types: []string{"rds"}, + Regions: []string{"us-central-1"}, + Tags: map[string]apiutils.Strings{"c": {"d"}}, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", }, } conf.Databases.AzureMatchers = []AzureMatcher{ @@ -2605,17 +2628,23 @@ func TestDatabaseCLIFlags(t *testing.T) { { desc: "RDS database", inFlags: CommandLineFlags{ - DatabaseName: "rds", - DatabaseProtocol: defaults.ProtocolMySQL, - DatabaseURI: "localhost:3306", - DatabaseAWSRegion: "us-east-1", + DatabaseName: "rds", + DatabaseProtocol: defaults.ProtocolMySQL, + DatabaseURI: "localhost:3306", + DatabaseAWSRegion: "us-east-1", + DatabaseAWSAccountID: "123456789012", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, outDatabase: servicecfg.Database{ Name: "rds", Protocol: defaults.ProtocolMySQL, URI: "localhost:3306", AWS: servicecfg.DatabaseAWS{ - Region: "us-east-1", + Region: "us-east-1", + AccountID: "123456789012", // this gets derived from the assumed role. + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, StaticLabels: map[string]string{ types.OriginLabel: types.OriginConfigFile, @@ -2634,6 +2663,8 @@ func TestDatabaseCLIFlags(t *testing.T) { DatabaseURI: "localhost:5432", DatabaseAWSRegion: "us-east-1", DatabaseAWSRedshiftClusterID: "redshift-cluster-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, outDatabase: servicecfg.Database{ Name: "redshift", @@ -2644,6 +2675,9 @@ func TestDatabaseCLIFlags(t *testing.T) { Redshift: servicecfg.DatabaseAWSRedshift{ ClusterID: "redshift-cluster-1", }, + AccountID: "123456789012", // this gets derived from the assumed role. + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, StaticLabels: map[string]string{ types.OriginLabel: types.OriginConfigFile, @@ -2738,19 +2772,23 @@ func TestDatabaseCLIFlags(t *testing.T) { { desc: "AWS Keyspaces", inFlags: CommandLineFlags{ - DatabaseName: "keyspace", - DatabaseProtocol: defaults.ProtocolCassandra, - DatabaseURI: "cassandra.us-east-1.amazonaws.com:9142", - DatabaseAWSAccountID: "123456789012", - DatabaseAWSRegion: "us-east-1", + DatabaseName: "keyspace", + DatabaseProtocol: defaults.ProtocolCassandra, + DatabaseURI: "cassandra.us-east-1.amazonaws.com:9142", + DatabaseAWSAccountID: "123456789012", + DatabaseAWSRegion: "us-east-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, outDatabase: servicecfg.Database{ Name: "keyspace", Protocol: defaults.ProtocolCassandra, URI: "cassandra.us-east-1.amazonaws.com:9142", AWS: servicecfg.DatabaseAWS{ - Region: "us-east-1", - AccountID: "123456789012", + Region: "us-east-1", + AccountID: "123456789012", + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, StaticLabels: map[string]string{ types.OriginLabel: types.OriginConfigFile, @@ -2764,21 +2802,23 @@ func TestDatabaseCLIFlags(t *testing.T) { { desc: "AWS DynamoDB", inFlags: CommandLineFlags{ - DatabaseName: "ddb", - DatabaseProtocol: defaults.ProtocolDynamoDB, - DatabaseURI: "dynamodb.us-east-1.amazonaws.com", - DatabaseAWSAccountID: "123456789012", - DatabaseAWSExternalID: "12345678901234", - DatabaseAWSRegion: "us-east-1", + DatabaseName: "ddb", + DatabaseProtocol: defaults.ProtocolDynamoDB, + DatabaseURI: "dynamodb.us-east-1.amazonaws.com", + DatabaseAWSAccountID: "123456789012", + DatabaseAWSRegion: "us-east-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, outDatabase: servicecfg.Database{ Name: "ddb", Protocol: defaults.ProtocolDynamoDB, URI: "dynamodb.us-east-1.amazonaws.com", AWS: servicecfg.DatabaseAWS{ - Region: "us-east-1", - AccountID: "123456789012", - ExternalID: "12345678901234", + Region: "us-east-1", + AccountID: "123456789012", + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, StaticLabels: map[string]string{ types.OriginLabel: types.OriginConfigFile, diff --git a/lib/config/database.go b/lib/config/database.go index a9904244cd2a0..ff82d0b7e87a0 100644 --- a/lib/config/database.go +++ b/lib/config/database.go @@ -333,7 +333,7 @@ db_service: tls: ca_cert_file: "{{ .DatabaseCACertFile }}" {{- end }} - {{- if or .DatabaseAWSRegion .DatabaseAWSAccountID .DatabaseAWSExternalID .DatabaseAWSRedshiftClusterID .DatabaseAWSRDSInstanceID .DatabaseAWSRDSClusterID .DatabaseAWSElastiCacheGroupID .DatabaseAWSMemoryDBClusterName }} + {{- if or .DatabaseAWSRegion .DatabaseAWSAccountID .DatabaseAWSAssumeRoleARN .DatabaseAWSExternalID .DatabaseAWSRedshiftClusterID .DatabaseAWSRDSInstanceID .DatabaseAWSRDSClusterID .DatabaseAWSElastiCacheGroupID .DatabaseAWSMemoryDBClusterName }} aws: {{- if .DatabaseAWSRegion }} region: "{{ .DatabaseAWSRegion }}" @@ -341,6 +341,9 @@ db_service: {{- if .DatabaseAWSAccountID }} account_id: "{{ .DatabaseAWSAccountID }}" {{- end }} + {{- if .DatabaseAWSAssumeRoleARN }} + assume_role_arn: "{{ .DatabaseAWSAssumeRoleARN }}" + {{- end }} {{- if .DatabaseAWSExternalID }} external_id: "{{ .DatabaseAWSExternalID }}" {{- end }} @@ -580,6 +583,8 @@ type DatabaseSampleFlags struct { DatabaseAWSRegion string // DatabaseAWSAccountID is an optional AWS account ID e.g. when using Keyspaces or DynamoDB. DatabaseAWSAccountID string + // DatabaseAWSAssumeRoleARN is an optional AWS IAM role ARN to assume when accessing the database. + DatabaseAWSAssumeRoleARN string // DatabaseAWSExternalID is an optional AWS database external ID, used when assuming roles. DatabaseAWSExternalID string // DatabaseAWSRedshiftClusterID is Redshift cluster identifier. diff --git a/lib/config/database_test.go b/lib/config/database_test.go index 2b0f333a84638..68c6ba015cd3c 100644 --- a/lib/config/database_test.go +++ b/lib/config/database_test.go @@ -232,25 +232,27 @@ func TestMakeDatabaseConfig(t *testing.T) { }, "AWSKeyspaces": { flags: DatabaseSampleFlags{ - StaticDatabaseName: "sample", - StaticDatabaseProtocol: "cassandra", - StaticDatabaseURI: "cassandra.us-west-1.amazonaws.com", - DatabaseCACertFile: pemfile, - DatabaseAWSRegion: "us-west-1", - DatabaseAWSAccountID: "123456789012", - DatabaseAWSExternalID: "1234567890", + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "cassandra", + StaticDatabaseURI: "cassandra.us-west-1.amazonaws.com", + DatabaseCACertFile: pemfile, + DatabaseAWSRegion: "us-west-1", + DatabaseAWSAccountID: "123456789012", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, "AWSKeyspacesDeriveURIFromAWSRegion": { flags: DatabaseSampleFlags{ - StaticDatabaseName: "sample", - StaticDatabaseProtocol: "cassandra", - StaticDatabaseURI: "", - DatabaseCACertFile: pemfile, - DatabaseAWSRegion: "us-west-1", - DatabaseAWSAccountID: "123456789012", - DatabaseAWSExternalID: "1234567890", + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "cassandra", + StaticDatabaseURI: "", + DatabaseCACertFile: pemfile, + DatabaseAWSRegion: "us-west-1", + DatabaseAWSAccountID: "123456789012", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, @@ -261,6 +263,8 @@ func TestMakeDatabaseConfig(t *testing.T) { StaticDatabaseURI: "redshift-cluster-1.abcdefghijklmnop.us-west-1.redshift.amazonaws.com:5439", DatabaseAWSRegion: "us-west-1", DatabaseAWSRedshiftClusterID: "redshift-cluster-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, @@ -271,16 +275,20 @@ func TestMakeDatabaseConfig(t *testing.T) { StaticDatabaseURI: "rds-instance-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", DatabaseAWSRegion: "us-west-1", DatabaseAWSRDSInstanceID: "rsd-instance-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, "AWSRDSCluster": { flags: DatabaseSampleFlags{ - StaticDatabaseName: "sample", - StaticDatabaseProtocol: "postgres", - StaticDatabaseURI: "aurora-cluster-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", - DatabaseAWSRegion: "us-west-1", - DatabaseAWSRDSClusterID: "aurora-cluster-1", + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "postgres", + StaticDatabaseURI: "aurora-cluster-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + DatabaseAWSRegion: "us-west-1", + DatabaseAWSRDSClusterID: "aurora-cluster-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, @@ -291,6 +299,8 @@ func TestMakeDatabaseConfig(t *testing.T) { StaticDatabaseURI: "clustercfg.my-memorydb.xxxxxx.memorydb.us-east-1.amazonaws.com:6379", DatabaseAWSRegion: "us-west-1", DatabaseAWSMemoryDBClusterName: "my-memorydb", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, @@ -301,6 +311,8 @@ func TestMakeDatabaseConfig(t *testing.T) { StaticDatabaseURI: "master.redis-cluster-example.abcdef.usw1.cache.amazonaws.com:6379", DatabaseAWSRegion: "us-west-1", DatabaseAWSElastiCacheGroupID: "redis-cluster-example", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, @@ -327,10 +339,12 @@ func TestMakeDatabaseConfig(t *testing.T) { }, "DynamoDBDeriveURIFromAWSRegion": { flags: DatabaseSampleFlags{ - StaticDatabaseName: "sample", - StaticDatabaseProtocol: "dynamodb", - DatabaseAWSAccountID: "123456789012", - DatabaseAWSRegion: "us-west-1", + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "dynamodb", + DatabaseAWSAccountID: "123456789012", + DatabaseAWSRegion: "us-west-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + DatabaseAWSExternalID: "externalID123", }, requireFn: require.NoError, }, @@ -385,6 +399,54 @@ func TestMakeDatabaseConfig(t *testing.T) { }, requireFn: require.Error, }, + "AWSExternalIDMissingAWSRoleARN": { + flags: DatabaseSampleFlags{ + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "postgres", + StaticDatabaseURI: "aurora-cluster-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + DatabaseAWSRegion: "us-west-1", + DatabaseAWSRDSClusterID: "aurora-cluster-1", + DatabaseAWSAssumeRoleARN: "", // missing role arn raises error because external id is set. + DatabaseAWSExternalID: "externalID123", + }, + requireFn: require.Error, + }, + "MissingAWSRoleARNName": { + flags: DatabaseSampleFlags{ + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "postgres", + StaticDatabaseURI: "aurora-cluster-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + DatabaseAWSRegion: "us-west-1", + DatabaseAWSRDSClusterID: "aurora-cluster-1", + DatabaseAWSAssumeRoleARN: "arn:aws:iam::123456789012:role", // missing role name + DatabaseAWSExternalID: "externalID123", + }, + requireFn: require.Error, + }, + "InvalidAWSRoleARNFormat": { + flags: DatabaseSampleFlags{ + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "postgres", + StaticDatabaseURI: "aurora-cluster-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + DatabaseAWSRegion: "us-west-1", + DatabaseAWSRDSClusterID: "aurora-cluster-1", + DatabaseAWSAssumeRoleARN: "foobar", + DatabaseAWSExternalID: "externalID123", + }, + requireFn: require.Error, + }, + "InvalidAWSRoleARNResourceService": { + flags: DatabaseSampleFlags{ + StaticDatabaseName: "sample", + StaticDatabaseProtocol: "postgres", + StaticDatabaseURI: "aurora-cluster-1.abcdefghijklmnop.us-west-1.rds.amazonaws.com:5432", + DatabaseAWSRegion: "us-west-1", + DatabaseAWSRDSClusterID: "aurora-cluster-1", + DatabaseAWSAssumeRoleARN: "arn:aws:sts::123456789012:federated-user/Alice", // sts != iam + DatabaseAWSExternalID: "externalID123", + }, + requireFn: require.Error, + }, } for name, tt := range tests { diff --git a/lib/config/fileconf.go b/lib/config/fileconf.go index 2f56334928a17..837873fab0442 100644 --- a/lib/config/fileconf.go +++ b/lib/config/fileconf.go @@ -53,6 +53,7 @@ import ( "github.com/gravitational/teleport/lib/services" "github.com/gravitational/teleport/lib/sshutils/x11" "github.com/gravitational/teleport/lib/utils" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // FileConfig structure represents the teleport configuration stored in a config file @@ -507,6 +508,16 @@ func checkAndSetDefaultsForAWSMatchers(matcherInput []AWSMatcher) error { } } + if matcher.AssumeRoleARN != "" { + _, err := awsutils.ParseRoleARN(matcher.AssumeRoleARN) + if err != nil { + return trace.Wrap(err, "discovery service AWS matcher assume_role_arn is invalid") + } + } else if matcher.ExternalID != "" { + return trace.BadParameter("discovery service AWS matcher assume_role_arn is empty, but has external_id %q", + matcher.ExternalID) + } + if matcher.Tags == nil || len(matcher.Tags) == 0 { matcher.Tags = map[string]apiutils.Strings{types.Wildcard: {types.Wildcard}} } @@ -1625,6 +1636,11 @@ type AWSMatcher struct { Types []string `yaml:"types,omitempty"` // Regions are AWS regions to query for databases. Regions []string `yaml:"regions,omitempty"` + // AssumeRoleARN is the AWS role to assume for database discovery. + AssumeRoleARN string `yaml:"assume_role_arn,omitempty"` + // ExternalID is the AWS external ID to use when assuming a role for + // database discovery in an external AWS account. + ExternalID string `yaml:"external_id,omitempty"` // Tags are AWS tags to match. Tags map[string]apiutils.Strings `yaml:"tags,omitempty"` // InstallParams sets the join method when installing on @@ -1783,6 +1799,8 @@ type DatabaseAWS struct { MemoryDB DatabaseAWSMemoryDB `yaml:"memorydb"` // AccountID is the AWS account ID. AccountID string `yaml:"account_id,omitempty"` + // AssumeRoleARN is the AWS role to assume to before accessing the database. + AssumeRoleARN string `yaml:"assume_role_arn,omitempty"` // ExternalID is an optional AWS external ID used to enable assuming an AWS role across accounts. ExternalID string `yaml:"external_id,omitempty"` // RedshiftServerless contains RedshiftServerless specific settings. diff --git a/lib/config/fileconf_test.go b/lib/config/fileconf_test.go index 1b5d77bbae43b..96778adcf3ad6 100644 --- a/lib/config/fileconf_test.go +++ b/lib/config/fileconf_test.go @@ -922,6 +922,8 @@ func TestDiscoveryConfig(t *testing.T) { "ssm": cfgMap{ "document_name": "hello_document", }, + "assume_role_arn": "arn:aws:iam::123456789012:role/DBDiscoverer", + "external_id": "externalID123", }, } }, @@ -941,7 +943,9 @@ func TestDiscoveryConfig(t *testing.T) { SSHDConfig: "/etc/ssh/sshd_config", ScriptName: "installer-custom", }, - SSM: AWSSSM{DocumentName: "hello_document"}, + SSM: AWSSSM{DocumentName: "hello_document"}, + AssumeRoleARN: "arn:aws:iam::123456789012:role/DBDiscoverer", + ExternalID: "externalID123", }, }, }, @@ -983,6 +987,64 @@ func TestDiscoveryConfig(t *testing.T) { }, expectedDiscoverySection: Discovery{}, }, + { + desc: "AWS section is filled with external_id but empty assume_role_arn", + expectError: require.Error, + expectEnabled: require.True, + mutate: func(cfg cfgMap) { + cfg["discovery_service"].(cfgMap)["enabled"] = "yes" + cfg["discovery_service"].(cfgMap)["aws"] = []cfgMap{ + { + "types": []string{"rds"}, + "regions": []string{"us-west-1"}, + "assume_role_arn": "", + "external_id": "externalid123", + "tags": cfgMap{ + "discover_teleport": "yes", + }, + }, + } + }, + expectedDiscoverySection: Discovery{}, + }, + { + desc: "AWS section is filled with invalid assume_role_arn", + expectError: require.Error, + expectEnabled: require.True, + mutate: func(cfg cfgMap) { + cfg["discovery_service"].(cfgMap)["enabled"] = "yes" + cfg["discovery_service"].(cfgMap)["aws"] = []cfgMap{ + { + "types": []string{"rds"}, + "regions": []string{"us-west-1"}, + "assume_role_arn": "foobar", + "tags": cfgMap{ + "discover_teleport": "yes", + }, + }, + } + }, + expectedDiscoverySection: Discovery{}, + }, + { + desc: "AWS section is filled with assume_role_arn that is not an iam ARN", + expectError: require.Error, + expectEnabled: require.True, + mutate: func(cfg cfgMap) { + cfg["discovery_service"].(cfgMap)["enabled"] = "yes" + cfg["discovery_service"].(cfgMap)["aws"] = []cfgMap{ + { + "types": []string{"rds"}, + "regions": []string{"us-west-1"}, + "assume_role_arn": "arn:aws:sts::123456789012:federated-user/Alice", + "tags": cfgMap{ + "discover_teleport": "yes", + }, + }, + } + }, + expectedDiscoverySection: Discovery{}, + }, { desc: "AWS section is filled with no token", expectError: require.NoError, diff --git a/lib/config/testdata_test.go b/lib/config/testdata_test.go index 6fc7f2c6a7d00..eb5caa39c390d 100644 --- a/lib/config/testdata_test.go +++ b/lib/config/testdata_test.go @@ -180,6 +180,11 @@ db_service: regions: ["westus"] tags: "c": "d" + aws: + - types: ["rds"] + regions: ["us-west-1"] + assume_role_arn: "arn:aws:iam::123456789012:role/DBDiscoverer" + external_id: "externalID123" kubernetes_service: enabled: yes @@ -195,6 +200,8 @@ discovery_service: aws: - types: ["ec2"] regions: ["eu-central-1"] + assume_role_arn: "arn:aws:iam::123456789012:role/DBDiscoverer" + external_id: "externalID123" ` // NoServicesConfigString is a configuration file with no services enabled diff --git a/lib/service/servicecfg/database.go b/lib/service/servicecfg/database.go index ef5471aec6d6b..cdaadaf48f976 100644 --- a/lib/service/servicecfg/database.go +++ b/lib/service/servicecfg/database.go @@ -24,6 +24,7 @@ import ( "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/limiter" "github.com/gravitational/teleport/lib/services" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // DatabasesConfig configures the database proxy service. @@ -99,6 +100,17 @@ func (d *Database) CheckAndSetDefaults() error { } } + // if AWS account id is missing, but assume role arn is given, + // try to parse the role arn and set the account id to match. + if d.AWS.AccountID == "" && d.AWS.AssumeRoleARN != "" { + parsed, err := awsutils.ParseRoleARN(d.AWS.AssumeRoleARN) + if err != nil { + return trace.BadParameter("database %q invalid AWS assume_role_arn: %v", + d.Name, err) + } + d.AWS.AccountID = parsed.AccountID + } + // Do a test run with extra validations. db, err := d.ToDatabase() if err != nil { @@ -126,9 +138,10 @@ func (d *Database) ToDatabase() (types.Database, error) { ServerVersion: d.MySQL.ServerVersion, }, AWS: types.AWS{ - AccountID: d.AWS.AccountID, - ExternalID: d.AWS.ExternalID, - Region: d.AWS.Region, + AccountID: d.AWS.AccountID, + AssumeRoleARN: d.AWS.AssumeRoleARN, + ExternalID: d.AWS.ExternalID, + Region: d.AWS.Region, Redshift: types.Redshift{ ClusterID: d.AWS.Redshift.ClusterID, }, @@ -204,6 +217,8 @@ type DatabaseAWS struct { SecretStore DatabaseAWSSecretStore // AccountID is the AWS account ID. AccountID string + // AssumeRoleARN is the AWS role to assume to before accessing the database. + AssumeRoleARN string // ExternalID is an optional AWS external ID used to enable assuming an AWS role across accounts. ExternalID string // RedshiftServerless contains AWS Redshift Serverless specific settings. diff --git a/lib/services/database.go b/lib/services/database.go index 2f9e595efbaba..b9908caf7488a 100644 --- a/lib/services/database.go +++ b/lib/services/database.go @@ -45,7 +45,7 @@ import ( "k8s.io/apimachinery/pkg/util/validation" "github.com/gravitational/teleport/api/types" - awsutils "github.com/gravitational/teleport/api/utils/aws" + apiawsutils "github.com/gravitational/teleport/api/utils/aws" azureutils "github.com/gravitational/teleport/api/utils/azure" libcloudaws "github.com/gravitational/teleport/lib/cloud/aws" "github.com/gravitational/teleport/lib/cloud/azure" @@ -53,6 +53,7 @@ import ( "github.com/gravitational/teleport/lib/srv/db/redis/connection" "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" + awsutils "github.com/gravitational/teleport/lib/utils/aws" ) // DatabaseGetter defines interface for fetching database resources. @@ -195,6 +196,23 @@ func ValidateDatabase(db types.Database) error { return trace.BadParameter("missing service principal name for database %q", db.GetName()) } } + + awsMeta := db.GetAWS() + if awsMeta.AssumeRoleARN != "" { + if awsMeta.AccountID == "" { + return trace.BadParameter("database %q missing AWS account ID", db.GetName()) + } + parsed, err := awsutils.ParseRoleARN(awsMeta.AssumeRoleARN) + if err != nil { + return trace.BadParameter("database %q assume_role_arn %q is invalid: %v", + db.GetName(), awsMeta.AssumeRoleARN, err) + } + err = awsutils.CheckARNPartitionAndAccount(parsed, awsMeta.Partition(), awsMeta.AccountID) + if err != nil { + return trace.BadParameter("database %q is incompatible with AWS assume_role_arn %q: %v", + db.GetName(), awsMeta.AssumeRoleARN, err) + } + } return nil } @@ -586,7 +604,7 @@ func NewDatabasesFromRDSClusterCustomEndpoints(cluster *rds.DBCluster) (types.Da for _, endpoint := range cluster.CustomEndpoints { // RDS custom endpoint format: // .cluster-custom-. - endpointDetails, err := awsutils.ParseRDSEndpoint(aws.StringValue(endpoint)) + endpointDetails, err := apiawsutils.ParseRDSEndpoint(aws.StringValue(endpoint)) if err != nil { errors = append(errors, trace.Wrap(err)) continue @@ -713,7 +731,7 @@ func NewDatabaseFromElastiCacheConfigurationEndpoint(cluster *elasticache.Replic return nil, trace.BadParameter("missing configuration endpoint") } - return newElastiCacheDatabase(cluster, cluster.ConfigurationEndpoint, awsutils.ElastiCacheConfigurationEndpoint, extraLabels) + return newElastiCacheDatabase(cluster, cluster.ConfigurationEndpoint, apiawsutils.ElastiCacheConfigurationEndpoint, extraLabels) } // NewDatabasesFromElastiCacheNodeGroups creates database resources from @@ -722,7 +740,7 @@ func NewDatabasesFromElastiCacheNodeGroups(cluster *elasticache.ReplicationGroup var databases types.Databases for _, nodeGroup := range cluster.NodeGroups { if nodeGroup.PrimaryEndpoint != nil { - database, err := newElastiCacheDatabase(cluster, nodeGroup.PrimaryEndpoint, awsutils.ElastiCachePrimaryEndpoint, extraLabels) + database, err := newElastiCacheDatabase(cluster, nodeGroup.PrimaryEndpoint, apiawsutils.ElastiCachePrimaryEndpoint, extraLabels) if err != nil { return nil, trace.Wrap(err) } @@ -730,7 +748,7 @@ func NewDatabasesFromElastiCacheNodeGroups(cluster *elasticache.ReplicationGroup } if nodeGroup.ReaderEndpoint != nil { - database, err := newElastiCacheDatabase(cluster, nodeGroup.ReaderEndpoint, awsutils.ElastiCacheReaderEndpoint, extraLabels) + database, err := newElastiCacheDatabase(cluster, nodeGroup.ReaderEndpoint, apiawsutils.ElastiCacheReaderEndpoint, extraLabels) if err != nil { return nil, trace.Wrap(err) } @@ -748,7 +766,7 @@ func newElastiCacheDatabase(cluster *elasticache.ReplicationGroup, endpoint *ela } suffix := make([]string, 0) - if endpointType == awsutils.ElastiCacheReaderEndpoint { + if endpointType == apiawsutils.ElastiCacheReaderEndpoint { suffix = []string{endpointType} } @@ -765,7 +783,7 @@ func newElastiCacheDatabase(cluster *elasticache.ReplicationGroup, endpoint *ela // NewDatabaseFromMemoryDBCluster creates a database resource from a MemoryDB // cluster. func NewDatabaseFromMemoryDBCluster(cluster *memorydb.Cluster, extraLabels map[string]string) (types.Database, error) { - endpointType := awsutils.MemoryDBClusterEndpoint + endpointType := apiawsutils.MemoryDBClusterEndpoint metadata, err := MetadataFromMemoryDBCluster(cluster, endpointType) if err != nil { diff --git a/lib/services/database_test.go b/lib/services/database_test.go index eaa48fbbdcf80..57cd93f838521 100644 --- a/lib/services/database_test.go +++ b/lib/services/database_test.go @@ -126,6 +126,42 @@ func TestValidateDatabase(t *testing.T) { }, expectError: true, }, + { + inputName: "invalid-database-assume-role-arn", + inputSpec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolDynamoDB, + AWS: types.AWS{ + Region: "us-east-1", + AccountID: "123456789012", + AssumeRoleARN: "foobar", + }, + }, + expectError: true, + }, + { + inputName: "invalid-database-assume-role-arn-resource-type", + inputSpec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolDynamoDB, + AWS: types.AWS{ + Region: "us-east-1", + AccountID: "123456789012", + AssumeRoleARN: "arn:aws:sts::123456789012:federated-user/Alice", + }, + }, + expectError: true, + }, + { + inputName: "invalid-database-assume-role-arn-account-id-mismatch", + inputSpec: types.DatabaseSpecV3{ + Protocol: defaults.ProtocolDynamoDB, + AWS: types.AWS{ + Region: "us-east-1", + AccountID: "123456789012", + AssumeRoleARN: "arn:aws:iam::111222333444:federated-user/Alice", + }, + }, + expectError: true, + }, { inputName: "invalid-database-CA-cert", inputSpec: types.DatabaseSpecV3{ diff --git a/lib/services/matchers.go b/lib/services/matchers.go index 21e1f56b43cd7..fd33b626f1b8b 100644 --- a/lib/services/matchers.go +++ b/lib/services/matchers.go @@ -69,12 +69,38 @@ type InstallerParams struct { PublicProxyAddr string } +// AssumeRole provides a role ARN and ExternalID to assume an AWS role +// when interacting with AWS resources. +type AssumeRole struct { + // RoleARN is the fully specified AWS IAM role ARN. + RoleARN string + // ExternalID is the external ID used to assume a role in another account. + ExternalID string +} + +// IsEmpty is a helper function that returns whether the assume role info +// is empty. +func (a *AssumeRole) IsEmpty() bool { + return a.RoleARN == "" && a.ExternalID == "" +} + +// AssumeRoleFromAWSMetadata is a conversion helper function that extracts +// AWS IAM role ARN and external ID from AWS metadata. +func AssumeRoleFromAWSMetadata(meta *types.AWS) AssumeRole { + return AssumeRole{ + RoleARN: meta.AssumeRoleARN, + ExternalID: meta.ExternalID, + } +} + // AWSMatcher matches AWS databases. type AWSMatcher struct { // Types are AWS database types to match, "rds" or "redshift". Types []string // Regions are AWS regions to query for databases. Regions []string + // AssumeRole is the AWS role to assume when discovering AWS databases. + AssumeRole AssumeRole // Tags are AWS tags to match. Tags types.Labels // Params are passed to AWS when executing the SSM document diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 20540f436a07b..1f0c39e13db84 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -238,6 +238,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con dbStartCmd.Flag("ca-cert", "Database CA certificate path.").StringVar(&ccf.DatabaseCACertFile) dbStartCmd.Flag("aws-region", "(Only for RDS, Aurora, Redshift, ElastiCache or MemoryDB) AWS region AWS hosted database instance is running in.").StringVar(&ccf.DatabaseAWSRegion) dbStartCmd.Flag("aws-account-id", "(Only for Keyspaces or DynamoDB) AWS Account ID.").StringVar(&ccf.DatabaseAWSAccountID) + dbStartCmd.Flag("aws-assume-role-arn", "Optional AWS IAM role to assume.").StringVar(&ccf.DatabaseAWSAssumeRoleARN) dbStartCmd.Flag("aws-external-id", "Optional AWS external ID used when assuming an AWS role.").StringVar(&ccf.DatabaseAWSExternalID) dbStartCmd.Flag("aws-redshift-cluster-id", "(Only for Redshift) Redshift database cluster identifier.").StringVar(&ccf.DatabaseAWSRedshiftClusterID) dbStartCmd.Flag("aws-rds-instance-id", "(Only for RDS) RDS instance identifier.").StringVar(&ccf.DatabaseAWSRDSInstanceID) @@ -280,6 +281,7 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con dbConfigureCreate.Flag("labels", "Comma-separated list of labels for the database, for example env=dev,dept=it").StringVar(&dbConfigCreateFlags.StaticDatabaseRawLabels) dbConfigureCreate.Flag("aws-region", "(Only for AWS-hosted databases) AWS region RDS, Aurora, Redshift, Redshift Serverless, ElastiCache, or MemoryDB database instance is running in.").StringVar(&dbConfigCreateFlags.DatabaseAWSRegion) dbConfigureCreate.Flag("aws-account-id", "(Only for Keyspaces or DynamoDB) AWS Account ID.").StringVar(&dbConfigCreateFlags.DatabaseAWSAccountID) + dbConfigureCreate.Flag("aws-assume-role-arn", "Optional AWS IAM role to assume.").StringVar(&dbConfigCreateFlags.DatabaseAWSAssumeRoleARN) dbConfigureCreate.Flag("aws-external-id", "(Only for AWS-hosted databases) Optional AWS external ID to use when assuming AWS roles.").StringVar(&dbConfigCreateFlags.DatabaseAWSExternalID) dbConfigureCreate.Flag("aws-redshift-cluster-id", "(Only for Redshift) Redshift database cluster identifier.").StringVar(&dbConfigCreateFlags.DatabaseAWSRedshiftClusterID) dbConfigureCreate.Flag("aws-rds-cluster-id", "(Only for RDS Aurora) RDS Aurora database cluster identifier.").StringVar(&dbConfigCreateFlags.DatabaseAWSRDSClusterID) From 85746c0e93c85108660cea6ea1ec854de3ee1f41 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Fri, 17 Mar 2023 20:47:31 -0700 Subject: [PATCH 3/4] Update lib/service/servicecfg/database.go Co-authored-by: STeve (Xin) Huang --- lib/service/servicecfg/database.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/service/servicecfg/database.go b/lib/service/servicecfg/database.go index cdaadaf48f976..91cc52e1292e3 100644 --- a/lib/service/servicecfg/database.go +++ b/lib/service/servicecfg/database.go @@ -100,7 +100,7 @@ func (d *Database) CheckAndSetDefaults() error { } } - // if AWS account id is missing, but assume role arn is given, + // If AWS account ID is missing, but assume role ARN is given, // try to parse the role arn and set the account id to match. if d.AWS.AccountID == "" && d.AWS.AssumeRoleARN != "" { parsed, err := awsutils.ParseRoleARN(d.AWS.AssumeRoleARN) From 4577ddbe6fa0f2b4d5ebbc5e3576fd024faa23b8 Mon Sep 17 00:00:00 2001 From: Gavin Frazar Date: Fri, 17 Mar 2023 21:35:49 -0700 Subject: [PATCH 4/4] fixup database checks * external_id does not require assume_role_arn for DynamoDB, Keyspaces, or Redshift Serverless * keyspaces gives a hint that URI could not be derived because region was also empty * make some error messages more consistent with each other regarding missing/empty fields. --- api/types/database.go | 29 ++++++++++++++++++----------- api/types/database_test.go | 29 ++++++++++++++++++----------- 2 files changed, 36 insertions(+), 22 deletions(-) diff --git a/api/types/database.go b/api/types/database.go index 7be101fe32afd..0155b200608a3 100644 --- a/api/types/database.go +++ b/api/types/database.go @@ -508,16 +508,21 @@ func (d *DatabaseV3) CheckAndSetDefaults() error { return trace.BadParameter("database %q protocol is empty", d.GetName()) } if d.Spec.URI == "" { - switch { - case d.IsAWSKeyspaces() && d.Spec.AWS.Region != "": - // In case of AWS Hosted Cassandra allow to omit URI. - // The URL will be constructed from the database resource based on the region and account ID. - d.Spec.URI = awsutils.CassandraEndpointURLForRegion(d.Spec.AWS.Region) - case d.IsDynamoDB(): + switch d.GetType() { + case DatabaseTypeAWSKeyspaces: + if d.Spec.AWS.Region != "" { + // In case of AWS Hosted Cassandra allow to omit URI. + // The URL will be constructed from the database resource based on the region and account ID. + d.Spec.URI = awsutils.CassandraEndpointURLForRegion(d.Spec.AWS.Region) + } else { + return trace.BadParameter("AWS Keyspaces database %q URI is empty and cannot be derived without a configured AWS region", + d.GetName()) + } + case DatabaseTypeDynamoDB: if d.Spec.AWS.Region != "" { d.Spec.URI = awsutils.DynamoDBURIForRegion(d.Spec.AWS.Region) } else { - return trace.BadParameter("DynamoDB database %q URI is missing and cannot be derived from an empty configured AWS region", + return trace.BadParameter("DynamoDB database %q URI is empty and cannot be derived without a configured AWS region", d.GetName()) } default: @@ -679,8 +684,10 @@ func (d *DatabaseV3) CheckAndSetDefaults() error { } } - if d.Spec.AWS.AssumeRoleARN == "" && d.Spec.AWS.ExternalID != "" { - return trace.BadParameter("AWS database %q has external_id %q, but assume_role_arn is missing", + if d.Spec.AWS.ExternalID != "" && d.Spec.AWS.AssumeRoleARN == "" && !d.RequireAWSIAMRolesAsUsers() { + // Databases that use database username to assume an IAM role do not + // need assume_role_arn in configuration when external_id is set. + return trace.BadParameter("AWS database %q has external_id %q, but assume_role_arn is empty", d.GetName(), d.Spec.AWS.ExternalID) } @@ -699,7 +706,7 @@ func (d *DatabaseV3) CheckAndSetDefaults() error { // handleDynamoDBConfig handles DynamoDB configuration checking. func (d *DatabaseV3) handleDynamoDBConfig() error { if d.Spec.AWS.AccountID == "" { - return trace.BadParameter("database %q AWS account ID is missing", d.GetName()) + return trace.BadParameter("database %q AWS account ID is empty", d.GetName()) } info, err := awsutils.ParseDynamoDBEndpoint(d.Spec.URI) @@ -709,7 +716,7 @@ func (d *DatabaseV3) handleDynamoDBConfig() error { // so we check if the region is configured to see if this is really a configuration error. if d.Spec.AWS.Region == "" { // the AWS region is empty and we can't derive it from the URI, so this is a config error. - return trace.BadParameter("database %q AWS region is missing and cannot be derived from the URI %q", + return trace.BadParameter("database %q AWS region is empty and cannot be derived from the URI %q", d.GetName(), d.Spec.URI) } if awsutils.IsAWSEndpoint(d.Spec.URI) { diff --git a/api/types/database_test.go b/api/types/database_test.go index 53cdf85f19003..c5febcfe9de84 100644 --- a/api/types/database_test.go +++ b/api/types/database_test.go @@ -644,6 +644,21 @@ func TestDynamoDBConfig(t *testing.T) { }, }, }, + { + desc: "configured external ID but not assume role is ok", + uri: "localhost:8080", + region: "us-west-1", + account: "123456789012", + externalID: "externalid123", + wantSpec: DatabaseSpecV3{ + URI: "localhost:8080", + AWS: AWS{ + Region: "us-west-1", + AccountID: "123456789012", + ExternalID: "externalid123", + }, + }, + }, { desc: "region and different AWS URI region is an error", uri: "dynamodb.us-west-2.amazonaws.com", @@ -662,12 +677,12 @@ func TestDynamoDBConfig(t *testing.T) { desc: "custom URI and missing region is an error", uri: "localhost:8080", account: "123456789012", - wantErrMsg: "region is missing", + wantErrMsg: "region is empty", }, { desc: "missing URI and missing region is an error", account: "123456789012", - wantErrMsg: "URI is missing", + wantErrMsg: "URI is empty", }, { desc: "invalid AWS account ID is an error", @@ -676,18 +691,10 @@ func TestDynamoDBConfig(t *testing.T) { account: "12345", wantErrMsg: "must be 12-digit", }, - { - desc: "configured external ID but not assume role is an error", - uri: "localhost:8080", - region: "us-west-1", - account: "123456789012", - externalID: "externalid123", - wantErrMsg: "assume_role_arn is missing", - }, { region: "us-west-1", desc: "missing account id", - wantErrMsg: "account ID is missing", + wantErrMsg: "account ID is empty", }, }