diff --git a/api/types/databaseservice.go b/api/types/databaseservice.go index e854a07bd0885..26810b2c77355 100644 --- a/api/types/databaseservice.go +++ b/api/types/databaseservice.go @@ -31,6 +31,8 @@ type DatabaseService interface { GetNamespace() string // GetResourceMatchers returns the resource matchers of the DatabaseService. + // Database services deployed by Teleport have known configurations where + // we will only define a single resource matcher. GetResourceMatchers() []*DatabaseResourceMatcher } diff --git a/lib/integrations/awsoidc/listdatabases.go b/lib/integrations/awsoidc/listdatabases.go index c169ef6cb3631..2bd7e959fea9c 100644 --- a/lib/integrations/awsoidc/listdatabases.go +++ b/lib/integrations/awsoidc/listdatabases.go @@ -103,6 +103,59 @@ func NewListDatabasesClient(ctx context.Context, req *AWSClientRequest) (ListDat return newRDSClient(ctx, req) } +// ListAllDatabases collects dbs until end of pages for all supported RDS engines and types. +func ListAllDatabases(ctx context.Context, clt ListDatabasesClient, region string) (*ListDatabasesResponse, error) { + fetchedRDSs := []types.Database{} + + // Get all rds instances. + nextToken := "" + for { + resp, err := ListDatabases(ctx, + clt, + ListDatabasesRequest{ + Region: region, + Engines: []string{services.RDSEngineMySQL, services.RDSEngineMariaDB, services.RDSEnginePostgres}, + RDSType: rdsTypeInstance, + NextToken: nextToken, + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + fetchedRDSs = append(fetchedRDSs, resp.Databases...) + nextToken = resp.NextToken + + if len(nextToken) == 0 { + break + } + } + + // Get all rds clusters. + nextToken = "" + for { + resp, err := ListDatabases(ctx, + clt, + ListDatabasesRequest{ + Region: region, + Engines: []string{services.RDSEngineAuroraMySQL, services.RDSEngineAuroraPostgres}, + RDSType: rdsTypeCluster, + NextToken: nextToken, + }, + ) + if err != nil { + return nil, trace.Wrap(err) + } + fetchedRDSs = append(fetchedRDSs, resp.Databases...) + nextToken = resp.NextToken + + if len(nextToken) == 0 { + break + } + } + + return &ListDatabasesResponse{Databases: fetchedRDSs}, nil +} + // ListDatabases calls the following AWS API: // https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBClusters.html // https://docs.aws.amazon.com/AmazonRDS/latest/APIReference/API_DescribeDBInstances.html diff --git a/lib/web/apiserver.go b/lib/web/apiserver.go index 89eb1df6291aa..6d8c7ea53ea2e 100644 --- a/lib/web/apiserver.go +++ b/lib/web/apiserver.go @@ -838,6 +838,7 @@ func (h *Handler) bindDefaultEndpoints() { h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/ec2ice", h.WithClusterAuth(h.awsOIDCListEC2ICE)) h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/deployec2ice", h.WithClusterAuth(h.awsOIDCDeployEC2ICE)) h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/securitygroups", h.WithClusterAuth(h.awsOIDCListSecurityGroups)) + h.POST("/webapi/sites/:site/integrations/aws-oidc/:name/requireddatabasesvpcs", h.WithClusterAuth(h.awsOIDCRequiredDatabasesVPCS)) h.GET("/webapi/scripts/integrations/configure/eice-iam.sh", h.WithLimiter(h.awsOIDCConfigureEICEIAM)) // AWS OIDC Integration specific endpoints: diff --git a/lib/web/integrations_awsoidc.go b/lib/web/integrations_awsoidc.go index 9485d62daaa11..227a3b43b284b 100644 --- a/lib/web/integrations_awsoidc.go +++ b/lib/web/integrations_awsoidc.go @@ -22,15 +22,20 @@ import ( "context" "fmt" "net/http" + "slices" "strings" "github.com/gravitational/trace" "github.com/julienschmidt/httprouter" "github.com/gravitational/teleport" + "github.com/gravitational/teleport/api/client" + "github.com/gravitational/teleport/api/client/proto" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/api/utils/aws" + "github.com/gravitational/teleport/lib/auth" + "github.com/gravitational/teleport/lib/defaults" "github.com/gravitational/teleport/lib/httplib" "github.com/gravitational/teleport/lib/integrations/awsoidc" "github.com/gravitational/teleport/lib/reversetunnelclient" @@ -372,6 +377,118 @@ func (h *Handler) awsOIDCListSecurityGroups(w http.ResponseWriter, r *http.Reque }, nil } +// awsOIDCRequiredDatabasesVPCS returns a map of required VPC's and its subnets. +// This is required during the web UI discover flow (where users opt for auto +// discovery) to determine if user can skip the auto deployment screen (where we deploy +// database agents). +// +// This api will return empty if we already have agents that can proxy the discovered databases. +// Otherwise it will return with a map of VPC and its subnets where it's values are later used +// to configure and deploy an agent (deploy an agent per unique VPC). +func (h *Handler) awsOIDCRequiredDatabasesVPCS(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (any, error) { + ctx := r.Context() + + var req ui.AWSOIDCRequiredVPCSRequest + if err := httplib.ReadJSON(r, &req); err != nil { + return nil, trace.Wrap(err) + } + + awsClientReq, err := h.awsOIDCClientRequest(ctx, req.Region, p, sctx, site) + if err != nil { + return nil, trace.Wrap(err) + } + + listDBsClient, err := awsoidc.NewListDatabasesClient(ctx, awsClientReq) + if err != nil { + return nil, trace.Wrap(err) + } + + clt, err := sctx.GetUserClient(ctx, site) + if err != nil { + return nil, trace.Wrap(err) + } + + resp, err := awsOIDCRequiredVPCSHelper(ctx, req, listDBsClient, clt) + if err != nil { + return nil, trace.Wrap(err) + } + return resp, nil +} + +func awsOIDCRequiredVPCSHelper(ctx context.Context, req ui.AWSOIDCRequiredVPCSRequest, listDBsClient awsoidc.ListDatabasesClient, clt auth.ClientI) (*ui.AWSOIDCRequiredVPCSResponse, error) { + resp, err := awsoidc.ListAllDatabases(ctx, listDBsClient, req.Region) + if err != nil { + return nil, trace.Wrap(err) + } + if len(resp.Databases) == 0 { + return nil, trace.BadParameter("there are no available RDS instances or clusters found in region %q", req.Region) + } + + // Get all database services with ecs/fargate metadata label. + nextToken := "" + fetchedDbSvcs := []types.DatabaseService{} + for { + page, err := client.GetResourcePage[types.DatabaseService](ctx, clt, &proto.ListResourcesRequest{ + ResourceType: types.KindDatabaseService, + Limit: defaults.MaxIterationLimit, + StartKey: nextToken, + Labels: map[string]string{types.AWSOIDCAgentLabel: types.True}, + }) + if err != nil { + return nil, trace.Wrap(err) + } + + fetchedDbSvcs = append(fetchedDbSvcs, page.Resources...) + nextToken = page.NextKey + if len(nextToken) == 0 { + break + } + } + + // Construct map of VPCs and its subnets. + vpcLookup := map[string][]string{} + for _, db := range resp.Databases { + rds := db.GetAWS().RDS + vpcId := rds.VPCID + if _, found := vpcLookup[vpcId]; !found { + vpcLookup[vpcId] = rds.Subnets + continue + } + combinedSubnets := append(vpcLookup[vpcId], rds.Subnets...) + vpcLookup[vpcId] = utils.Deduplicate(combinedSubnets) + } + + for _, svc := range fetchedDbSvcs { + if len(svc.GetResourceMatchers()) != 1 || svc.GetResourceMatchers()[0].Labels == nil { + continue + } + + // Database services deployed by Teleport have known configurations where + // we will only define a single resource matcher. + labelMatcher := *svc.GetResourceMatchers()[0].Labels + + // We check for length 3, because we are only + // wanting/checking for 3 discovery labels. + if len(labelMatcher) != 3 { + continue + } + if slices.Compare(labelMatcher[types.DiscoveryLabelAccountID], []string{req.AccountID}) != 0 { + continue + } + if slices.Compare(labelMatcher[types.DiscoveryLabelRegion], []string{req.Region}) != 0 { + continue + } + if len(labelMatcher[types.DiscoveryLabelVPCID]) != 1 { + continue + } + delete(vpcLookup, labelMatcher[types.DiscoveryLabelVPCID][0]) + } + + return &ui.AWSOIDCRequiredVPCSResponse{ + VPCMapOfSubnets: vpcLookup, + }, nil +} + // awsOIDCListEC2ICE returns a list of EC2 Instance Connect Endpoints using the ListEC2ICE action of the AWS OIDC Integration. func (h *Handler) awsOIDCListEC2ICE(w http.ResponseWriter, r *http.Request, p httprouter.Params, sctx *SessionContext, site reversetunnelclient.RemoteSite) (any, error) { ctx := r.Context() diff --git a/lib/web/integrations_awsoidc_test.go b/lib/web/integrations_awsoidc_test.go index 59fd97b77a24a..68c32af9baa95 100644 --- a/lib/web/integrations_awsoidc_test.go +++ b/lib/web/integrations_awsoidc_test.go @@ -24,8 +24,16 @@ import ( "net/url" "testing" + "github.com/aws/aws-sdk-go-v2/service/rds" + rdsTypes "github.com/aws/aws-sdk-go-v2/service/rds/types" + "github.com/aws/aws-sdk-go/aws" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/stretchr/testify/require" + + "github.com/gravitational/teleport/api/client/proto" + "github.com/gravitational/teleport/api/types" + "github.com/gravitational/teleport/lib/web/ui" ) func TestBuildDeployServiceConfigureIAMScript(t *testing.T) { @@ -336,6 +344,7 @@ func TestBuildAWSOIDCIdPConfigureScript(t *testing.T) { } func TestBuildListDatabasesConfigureIAMScript(t *testing.T) { + t.Parallel() isBadParamErrFn := func(tt require.TestingT, err error, i ...any) { require.True(tt, trace.IsBadParameter(err), "expected bad parameter, got %v", err) } @@ -421,3 +430,219 @@ func TestBuildListDatabasesConfigureIAMScript(t *testing.T) { }) } } + +func TestAWSOIDCRequiredVPCSHelper(t *testing.T) { + t.Parallel() + ctx := context.Background() + env := newWebPack(t, 1) + clt := env.proxies[0].client + + matchRegion := "us-east-1" + matchAccountId := "123456789012" + req := ui.AWSOIDCRequiredVPCSRequest{ + Region: matchRegion, + AccountID: matchAccountId, + } + + upsertDbSvcFn := func(vpcId string, matcher []*types.DatabaseResourceMatcher) { + if matcher == nil { + matcher = []*types.DatabaseResourceMatcher{ + { + Labels: &types.Labels{ + types.DiscoveryLabelAccountID: []string{matchAccountId}, + types.DiscoveryLabelRegion: []string{matchRegion}, + types.DiscoveryLabelVPCID: []string{vpcId}, + }, + }, + } + } + svc, err := types.NewDatabaseServiceV1(types.Metadata{ + Name: uuid.NewString(), + Labels: map[string]string{types.AWSOIDCAgentLabel: types.True}, + }, types.DatabaseServiceSpecV1{ + ResourceMatchers: matcher, + }) + require.NoError(t, err) + _, err = env.server.Auth().UpsertDatabaseService(ctx, svc) + require.NoError(t, err) + } + + extractKeysFn := func(resp *ui.AWSOIDCRequiredVPCSResponse) []string { + keys := make([]string, 0, len(resp.VPCMapOfSubnets)) + for k := range resp.VPCMapOfSubnets { + keys = append(keys, k) + } + return keys + } + + vpcs := []string{"vpc-1", "vpc-2", "vpc-3", "vpc-4", "vpc-5"} + rdss := []rdsTypes.DBInstance{} + for _, vpc := range vpcs { + rdss = append(rdss, rdsTypes.DBInstance{ + DBInstanceStatus: aws.String("available"), + DBInstanceIdentifier: aws.String(fmt.Sprintf("db-%v", vpc)), + DbiResourceId: aws.String("db-123"), + Engine: aws.String("postgres"), + DBInstanceArn: aws.String("arn:aws:iam::123456789012:role/MyARN"), + + Endpoint: &rdsTypes.Endpoint{ + Address: aws.String("endpoint.amazonaws.com"), + Port: aws.Int32(5432), + }, + DBSubnetGroup: &rdsTypes.DBSubnetGroup{ + Subnets: []rdsTypes.Subnet{{SubnetIdentifier: aws.String(fmt.Sprintf("subnet-for-%s", vpc))}}, + VpcId: aws.String(vpc), + }, + }) + } + + mockListClient := mockListDatabasesClient{dbInstances: rdss} + + // Double check we start with 0 db svcs. + s, err := env.server.Auth().ListResources(ctx, proto.ListResourcesRequest{ + ResourceType: types.KindDatabaseService, + }) + require.NoError(t, err) + require.Empty(t, s.Resources) + + // All vpc's required. + resp, err := awsOIDCRequiredVPCSHelper(ctx, req, mockListClient, clt) + require.NoError(t, err) + require.Len(t, resp.VPCMapOfSubnets, 5) + require.ElementsMatch(t, vpcs, extractKeysFn(resp)) + + // Insert two valid database services. + upsertDbSvcFn("vpc-1", nil) + upsertDbSvcFn("vpc-5", nil) + + // Insert two invalid database services. + upsertDbSvcFn("vpc-2", []*types.DatabaseResourceMatcher{ + { + Labels: &types.Labels{ + types.DiscoveryLabelAccountID: []string{matchAccountId}, + types.DiscoveryLabelRegion: []string{"us-east-2"}, // not matching region + types.DiscoveryLabelVPCID: []string{"vpc-2"}, + }, + }, + }) + upsertDbSvcFn("vpc-2a", []*types.DatabaseResourceMatcher{ + { + Labels: &types.Labels{ + types.DiscoveryLabelAccountID: []string{matchAccountId}, + types.DiscoveryLabelRegion: []string{matchRegion}, + types.DiscoveryLabelVPCID: []string{"vpc-2"}, + "something": []string{"extra"}, // not matching b/c extra label + }, + }, + }) + + // Double check services were created. + s, err = env.server.Auth().ListResources(ctx, proto.ListResourcesRequest{ + ResourceType: types.KindDatabaseService, + }) + require.NoError(t, err) + require.Len(t, s.Resources, 4) + + // Test that only 3 vpcs are required. + resp, err = awsOIDCRequiredVPCSHelper(ctx, req, mockListClient, clt) + require.NoError(t, err) + require.ElementsMatch(t, []string{"vpc-2", "vpc-3", "vpc-4"}, extractKeysFn(resp)) + + // Insert the rest of db services + upsertDbSvcFn("vpc-2", nil) + upsertDbSvcFn("vpc-3", nil) + upsertDbSvcFn("vpc-4", nil) + + // Test no required vpcs. + resp, err = awsOIDCRequiredVPCSHelper(ctx, req, mockListClient, clt) + require.NoError(t, err) + require.Empty(t, resp.VPCMapOfSubnets) +} + +func TestAWSOIDCRequiredVPCSHelper_CombinedSubnetsForAVpcID(t *testing.T) { + ctx := context.Background() + env := newWebPack(t, 1) + clt := env.proxies[0].client + + rdss := []rdsTypes.DBInstance{ + { + DBInstanceStatus: aws.String("available"), + DBInstanceIdentifier: aws.String("id-vpc1"), + DbiResourceId: aws.String("db-123"), + Engine: aws.String("postgres"), + DBInstanceArn: aws.String("arn:aws:iam::123456789012:role/MyARN"), + + Endpoint: &rdsTypes.Endpoint{ + Address: aws.String("endpoint.amazonaws.com"), + Port: aws.Int32(5432), + }, + DBSubnetGroup: &rdsTypes.DBSubnetGroup{ + Subnets: []rdsTypes.Subnet{ + {SubnetIdentifier: aws.String("subnet1")}, + {SubnetIdentifier: aws.String("subnet2")}, + }, + VpcId: aws.String("vpc-1"), + }, + }, + { + DBInstanceStatus: aws.String("available"), + DBInstanceIdentifier: aws.String("id-vpc1a"), + DbiResourceId: aws.String("db-123"), + Engine: aws.String("postgres"), + DBInstanceArn: aws.String("arn:aws:iam::123456789012:role/MyARN"), + + Endpoint: &rdsTypes.Endpoint{ + Address: aws.String("endpoint.amazonaws.com"), + Port: aws.Int32(5432), + }, + DBSubnetGroup: &rdsTypes.DBSubnetGroup{ + Subnets: []rdsTypes.Subnet{ + {SubnetIdentifier: aws.String("subnet2")}, + {SubnetIdentifier: aws.String("subnet3")}, + {SubnetIdentifier: aws.String("subnet4")}, + {SubnetIdentifier: aws.String("subnet1")}, + }, + VpcId: aws.String("vpc-1"), + }, + }, + { + DBInstanceStatus: aws.String("available"), + DBInstanceIdentifier: aws.String("id-vpc2"), + DbiResourceId: aws.String("db-123"), + Engine: aws.String("postgres"), + DBInstanceArn: aws.String("arn:aws:iam::123456789012:role/MyARN"), + + Endpoint: &rdsTypes.Endpoint{ + Address: aws.String("endpoint.amazonaws.com"), + Port: aws.Int32(5432), + }, + DBSubnetGroup: &rdsTypes.DBSubnetGroup{ + Subnets: []rdsTypes.Subnet{{SubnetIdentifier: aws.String("subnet8")}}, + + VpcId: aws.String("vpc-2"), + }, + }, + } + + mockListClient := mockListDatabasesClient{dbInstances: rdss} + + resp, err := awsOIDCRequiredVPCSHelper(ctx, ui.AWSOIDCRequiredVPCSRequest{Region: "us-east-1"}, mockListClient, clt) + require.NoError(t, err) + require.Len(t, resp.VPCMapOfSubnets, 2) + require.ElementsMatch(t, []string{"subnet1", "subnet2", "subnet3", "subnet4"}, resp.VPCMapOfSubnets["vpc-1"]) + require.ElementsMatch(t, []string{"subnet8"}, resp.VPCMapOfSubnets["vpc-2"]) +} + +type mockListDatabasesClient struct { + dbInstances []rdsTypes.DBInstance +} + +func (m mockListDatabasesClient) DescribeDBInstances(ctx context.Context, params *rds.DescribeDBInstancesInput, optFns ...func(*rds.Options)) (*rds.DescribeDBInstancesOutput, error) { + return &rds.DescribeDBInstancesOutput{ + DBInstances: m.dbInstances, + }, nil +} + +func (m mockListDatabasesClient) DescribeDBClusters(ctx context.Context, params *rds.DescribeDBClustersInput, optFns ...func(*rds.Options)) (*rds.DescribeDBClustersOutput, error) { + return &rds.DescribeDBClustersOutput{}, nil +} diff --git a/lib/web/ui/integration.go b/lib/web/ui/integration.go index f2d91fee178f9..07d45130f098d 100644 --- a/lib/web/ui/integration.go +++ b/lib/web/ui/integration.go @@ -231,6 +231,21 @@ type AWSOIDCListSecurityGroupsResponse struct { NextToken string `json:"nextToken,omitempty"` } +// AWSOIDCRequiredVPCSRequest is a request to get required (missing) VPC's and its subnets. +type AWSOIDCRequiredVPCSRequest struct { + // Region is the AWS Region. + Region string `json:"region"` + // AccountID is the AWS Account ID. + AccountID string `json:"accountId"` +} + +// AWSOIDCRequiredVPCSResponse returns a list of required VPC's and its subnets. +type AWSOIDCRequiredVPCSResponse struct { + // VPCMapOfSubnets is a map of vpc ids and its subnets. + // Will be empty if no vpc's are required. + VPCMapOfSubnets map[string][]string `json:"vpcMapOfSubnets"` +} + // AWSOIDCListEC2ICERequest is a request to ListEC2ICEs using the AWS OIDC Integration. type AWSOIDCListEC2ICERequest struct { // Region is the AWS Region. diff --git a/web/packages/teleport/src/config.ts b/web/packages/teleport/src/config.ts index d799e2a1dbff1..0195874d061f8 100644 --- a/web/packages/teleport/src/config.ts +++ b/web/packages/teleport/src/config.ts @@ -268,6 +268,8 @@ const cfg = { awsConfigureIamScriptEc2InstanceConnectPath: '/v1/webapi/scripts/integrations/configure/eice-iam.sh?awsRegion=:region&role=:iamRoleName', + awsRdsDbRequiredVpcsPath: + '/v1/webapi/sites/:clusterId/integrations/aws-oidc/:name/requireddatabasesvpcs', awsRdsDbListPath: '/v1/webapi/sites/:clusterId/integrations/aws-oidc/:name/databases', awsDeployTeleportServicePath: @@ -782,6 +784,15 @@ const cfg = { }); }, + getAwsRdsDbRequiredVpcsUrl(integrationName: string) { + const clusterId = cfg.proxyCluster; + + return generatePath(cfg.api.awsRdsDbRequiredVpcsPath, { + clusterId, + name: integrationName, + }); + }, + getAwsDeployTeleportServiceUrl(integrationName: string) { const clusterId = cfg.proxyCluster; diff --git a/web/packages/teleport/src/services/integrations/integrations.ts b/web/packages/teleport/src/services/integrations/integrations.ts index 7808660ce073b..074db2e030d1d 100644 --- a/web/packages/teleport/src/services/integrations/integrations.ts +++ b/web/packages/teleport/src/services/integrations/integrations.ts @@ -79,8 +79,17 @@ export const integrationService = { return api.get(cfg.api.thumbprintPath); }, + fetchAwsRdsRequiredVpcs( + integrationName: string, + body: { region: string; accountId: string } + ): Promise> { + return api + .post(cfg.getAwsRdsDbRequiredVpcsUrl(integrationName), body) + .then(resp => resp.vpcMapOfSubnets); + }, + fetchAwsRdsDatabases( - integrationName, + integrationName: string, rdsEngineIdentifier: RdsEngineIdentifier, req: { region: AwsOidcListDatabasesRequest['region'];