Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
42 changes: 42 additions & 0 deletions lib/configurators/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"context"
"errors"
"fmt"
"io"
"os"
"slices"
"strings"

Expand All @@ -38,6 +40,7 @@ import (
apiutils "github.com/gravitational/teleport/api/utils"
apiawsutils "github.com/gravitational/teleport/api/utils/aws"
awslib "github.com/gravitational/teleport/lib/cloud/aws"
awsimds "github.com/gravitational/teleport/lib/cloud/imds/aws"
"github.com/gravitational/teleport/lib/configurators"
"github.com/gravitational/teleport/lib/defaults"
"github.com/gravitational/teleport/lib/modules"
Expand Down Expand Up @@ -316,6 +319,41 @@ type ssmClient interface {
CreateDocument(ctx context.Context, params *ssm.CreateDocumentInput, optFns ...func(*ssm.Options)) (*ssm.CreateDocumentOutput, error)
}

type localRegionGetter interface {
GetRegion(context.Context) (string, error)
}

func getLocalRegion(ctx context.Context, localRegionGetter localRegionGetter) (string, bool) {
if localRegionGetter == nil {
imdsClient, err := awsimds.NewInstanceMetadataClient(ctx)
if err != nil || !imdsClient.IsAvailable(ctx) {
return "", false
}
localRegionGetter = imdsClient
}

region, err := localRegionGetter.GetRegion(ctx)
if err != nil || region == "" {
return "", false
}
return region, true
}

func getFallbackRegion(ctx context.Context, w io.Writer, localRegionGetter localRegionGetter) string {
if localRegion, ok := getLocalRegion(ctx, localRegionGetter); ok {
fmt.Fprintf(w, "Using region %q from instance metadata.\n", localRegion)
return localRegion
}

// Fallback to us-east-1, which also supports fips.
fmt.Fprint(w, `
Warning: No region found from the default AWS config or instance metadata. Defaulting to 'us-east-1'.
To avoid seeing this warning, please provide a region in your AWS config or through the AWS_REGION environment variable.

`)
return "us-east-1"
}

// CheckAndSetDefaults checks and set configuration default values.
func (c *ConfiguratorConfig) CheckAndSetDefaults() error {
ctx := context.Background()
Expand All @@ -342,6 +380,10 @@ func (c *ConfiguratorConfig) CheckAndSetDefaults() error {
if err != nil {
return trace.Wrap(err)
}

if cfg.Region == "" {
cfg.Region = getFallbackRegion(ctx, os.Stdout, nil)
}
c.awsCfg = &cfg
}

Expand Down
40 changes: 40 additions & 0 deletions lib/configurators/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ package aws
import (
"context"
"fmt"
"io"
"regexp"
"sort"
"testing"
Expand Down Expand Up @@ -1918,3 +1919,42 @@ func (m *iamMock) GetRole(ctx context.Context, input *iam.GetRoleInput, optFns .
arn := fmt.Sprintf("arn:%s:iam::%s:role%s%s", m.partition, m.account, path, roleName)
return &iam.GetRoleOutput{Role: &iamtypes.Role{Arn: &arn}}, nil
}

type mockLocalRegionGetter struct {
region string
err error
}

func (m mockLocalRegionGetter) GetRegion(context.Context) (string, error) {
return m.region, m.err
}

func Test_getFallbackRegion(t *testing.T) {
tests := []struct {
name string
localRegionGetter localRegionGetter
wantRegion string
}{
{
name: "fallback to retrieved local region",
localRegionGetter: mockLocalRegionGetter{
region: "my-local-region",
},
wantRegion: "my-local-region",
},
{
name: "fallback to us-east",
localRegionGetter: mockLocalRegionGetter{
err: fmt.Errorf("failed to get local region"),
},
wantRegion: "us-east-1",
},
}

for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
region := getFallbackRegion(context.Background(), io.Discard, test.localRegionGetter)
require.Equal(t, test.wantRegion, region)
})
}
}