Skip to content

Commit

Permalink
chore: proxy to RDS instances (aws#5453)
Browse files Browse the repository at this point in the history
By submitting this pull request, I confirm that you can use, modify, copy, and redistribute this contribution, under the Apache 2.0 License.
  • Loading branch information
dannyrandall authored Nov 8, 2023
1 parent 80d4794 commit 8f3adbd
Show file tree
Hide file tree
Showing 13 changed files with 545 additions and 77 deletions.
2 changes: 2 additions & 0 deletions internal/pkg/aws/resourcegroups/resourcegroups.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
const (
// ResourceTypeStateMachine is the resource type for the state machine of a job.
ResourceTypeStateMachine = "states:stateMachine"
// ResourceTypeRDS is the resource type for any rds resources.
ResourceTypeRDS = "rds"
)

type api interface {
Expand Down
117 changes: 115 additions & 2 deletions internal/pkg/cli/run_local.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,14 +20,17 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/request"
"github.com/aws/aws-sdk-go/aws/session"
sdkecs "github.com/aws/aws-sdk-go/service/ecs"
"github.com/aws/aws-sdk-go/service/rds"
sdksecretsmanager "github.com/aws/aws-sdk-go/service/secretsmanager"
sdkssm "github.com/aws/aws-sdk-go/service/ssm"
cmdtemplate "github.com/aws/copilot-cli/cmd/copilot/template"
"github.com/aws/copilot-cli/internal/pkg/aws/ecr"
awsecs "github.com/aws/copilot-cli/internal/pkg/aws/ecs"
"github.com/aws/copilot-cli/internal/pkg/aws/identity"
"github.com/aws/copilot-cli/internal/pkg/aws/resourcegroups"
"github.com/aws/copilot-cli/internal/pkg/aws/secretsmanager"
"github.com/aws/copilot-cli/internal/pkg/aws/sessions"
"github.com/aws/copilot-cli/internal/pkg/aws/ssm"
Expand Down Expand Up @@ -72,6 +75,15 @@ type hostFinder interface {
Hosts(context.Context) ([]orchestrator.Host, error)
}

type taggedResourceGetter interface {
GetResourcesByTags(string, map[string]string) ([]*resourcegroups.Resource, error)
}

type rdsDescriber interface {
DescribeDBInstancesPagesWithContext(context.Context, *rds.DescribeDBInstancesInput, func(*rds.DescribeDBInstancesOutput, bool) bool, ...request.Option) error
DescribeDBClustersPagesWithContext(context.Context, *rds.DescribeDBClustersInput, func(*rds.DescribeDBClustersOutput, bool) bool, ...request.Option) error
}

type recursiveWatcher interface {
Add(path string) error
Close() error
Expand Down Expand Up @@ -195,6 +207,8 @@ func newRunLocalOpts(vars runLocalVars) (*runLocalOpts, error) {
env: o.envName,
wkld: o.wkldName,
ecs: ecs.New(o.envManagerSess),
rg: resourcegroups.New(o.envManagerSess),
rds: rds.New(o.envManagerSess),
}
envDesc, err := describe.NewEnvDescriber(describe.NewEnvDescriberConfig{
App: o.appName,
Expand Down Expand Up @@ -427,7 +441,7 @@ func (o *runLocalOpts) getTask(ctx context.Context) (orchestrator.Task, error) {
}

if o.proxy {
pauseSecrets, err := sessionEnvVars(ctx, o.sess)
pauseSecrets, err := sessionEnvVars(ctx, o.envManagerSess)
if err != nil {
return orchestrator.Task{}, fmt.Errorf("get pause container secrets: %w", err)
}
Expand Down Expand Up @@ -793,6 +807,9 @@ type hostDiscoverer struct {
app string
env string
wkld string

rg taggedResourceGetter
rds rdsDescriber
}

func (h *hostDiscoverer) Hosts(ctx context.Context) ([]orchestrator.Host, error) {
Expand All @@ -815,12 +832,108 @@ func (h *hostDiscoverer) Hosts(ctx context.Context) ([]orchestrator.Host, error)
for _, alias := range sc.ClientAliases {
hosts = append(hosts, orchestrator.Host{
Name: aws.StringValue(alias.DnsName),
Port: strconv.Itoa(int(aws.Int64Value(alias.Port))),
Port: uint16(aws.Int64Value(alias.Port)),
})
}
}
}

rdsHosts, err := h.rdsHosts(ctx)
if err != nil {
return nil, fmt.Errorf("get rds hosts: %w", err)
}

return append(hosts, rdsHosts...), nil
}

// rdsHosts gets rds endpoints for workloads tagged for this workload
// or for the environment using direct AWS SDK calls.
func (h *hostDiscoverer) rdsHosts(ctx context.Context) ([]orchestrator.Host, error) {
var hosts []orchestrator.Host

resources, err := h.rg.GetResourcesByTags(resourcegroups.ResourceTypeRDS, map[string]string{
deploy.AppTagKey: h.app,
deploy.EnvTagKey: h.env,
})
switch {
case err != nil:
return nil, fmt.Errorf("get tagged resources: %w", err)
case len(resources) == 0:
return nil, nil
}

dbFilter := &rds.Filter{
Name: aws.String("db-instance-id"),
}
clusterFilter := &rds.Filter{
Name: aws.String("db-cluster-id"),
}
for i := range resources {
// we don't want resources that belong to other services
// but we do want env level services
if wkld, ok := resources[i].Tags[deploy.ServiceTagKey]; ok && wkld != h.wkld {
continue
}

arn, err := arn.Parse(resources[i].ARN)
if err != nil {
return nil, fmt.Errorf("invalid arn %q: %w", resources[i].ARN, err)
}

switch {
case strings.HasPrefix(arn.Resource, "db:"):
dbFilter.Values = append(dbFilter.Values, aws.String(resources[i].ARN))
case strings.HasPrefix(arn.Resource, "cluster:"):
clusterFilter.Values = append(clusterFilter.Values, aws.String(resources[i].ARN))
}
}

if len(dbFilter.Values) > 0 {
err = h.rds.DescribeDBInstancesPagesWithContext(ctx, &rds.DescribeDBInstancesInput{
Filters: []*rds.Filter{dbFilter},
}, func(out *rds.DescribeDBInstancesOutput, lastPage bool) bool {
for _, db := range out.DBInstances {
if db.Endpoint != nil {
hosts = append(hosts, orchestrator.Host{
Name: aws.StringValue(db.Endpoint.Address),
Port: uint16(aws.Int64Value(db.Endpoint.Port)),
})
}
}
return true
})
if err != nil {
return nil, fmt.Errorf("describe instances: %w", err)
}
}

if len(clusterFilter.Values) > 0 {
err = h.rds.DescribeDBClustersPagesWithContext(ctx, &rds.DescribeDBClustersInput{
Filters: []*rds.Filter{clusterFilter},
}, func(out *rds.DescribeDBClustersOutput, lastPage bool) bool {
for _, db := range out.DBClusters {
add := func(s *string) {
if s != nil {
hosts = append(hosts, orchestrator.Host{
Name: aws.StringValue(s),
Port: uint16(aws.Int64Value(db.Port)),
})
}
}

add(db.Endpoint)
add(db.ReaderEndpoint)
for i := range db.CustomEndpoints {
add(db.CustomEndpoints[i])
}
}
return true
})
if err != nil {
return nil, fmt.Errorf("describe clusters: %w", err)
}
}

return hosts, nil
}

Expand Down
Loading

0 comments on commit 8f3adbd

Please sign in to comment.