Skip to content

Commit

Permalink
add support for token file and eks container endpoint in general HTTP…
Browse files Browse the repository at this point in the history
… provider
  • Loading branch information
lucix-aws authored Nov 13, 2023
2 parents fe6c7bf + 0b32bd7 commit a07f333
Show file tree
Hide file tree
Showing 5 changed files with 231 additions and 67 deletions.
2 changes: 2 additions & 0 deletions CHANGELOG_PENDING.md
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,5 @@
### SDK Enhancements

### SDK Bugs
* `aws/defaults`: Feature updates to endpoint credentials provider.
* Add support for dynamic auth token from file and EKS container host in configured URI.
47 changes: 46 additions & 1 deletion aws/credentials/endpointcreds/provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ package endpointcreds

import (
"encoding/json"
"fmt"
"strings"
"time"

"github.com/aws/aws-sdk-go/aws"
Expand Down Expand Up @@ -69,7 +71,37 @@ type Provider struct {

// Optional authorization token value if set will be used as the value of
// the Authorization header of the endpoint credential request.
//
// When constructed from environment, the provider will use the value of
// AWS_CONTAINER_AUTHORIZATION_TOKEN environment variable as the token
//
// Will be overridden if AuthorizationTokenProvider is configured
AuthorizationToken string

// Optional auth provider func to dynamically load the auth token from a file
// everytime a credential is retrieved
//
// When constructed from environment, the provider will read and use the content
// of the file pointed to by AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE environment variable
// as the auth token everytime credentials are retrieved
//
// Will override AuthorizationToken if configured
AuthorizationTokenProvider AuthTokenProvider
}

// AuthTokenProvider defines an interface to dynamically load a value to be passed
// for the Authorization header of a credentials request.
type AuthTokenProvider interface {
GetToken() (string, error)
}

// TokenProviderFunc is a func type implementing AuthTokenProvider interface
// and enables customizing token provider behavior
type TokenProviderFunc func() (string, error)

// GetToken func retrieves auth token according to TokenProviderFunc implementation
func (p TokenProviderFunc) GetToken() (string, error) {
return p()
}

// NewProviderClient returns a credentials Provider for retrieving AWS credentials
Expand Down Expand Up @@ -164,7 +196,20 @@ func (p *Provider) getCredentials(ctx aws.Context) (*getCredentialsOutput, error
req := p.Client.NewRequest(op, nil, out)
req.SetContext(ctx)
req.HTTPRequest.Header.Set("Accept", "application/json")
if authToken := p.AuthorizationToken; len(authToken) != 0 {

authToken := p.AuthorizationToken
var err error
if p.AuthorizationTokenProvider != nil {
authToken, err = p.AuthorizationTokenProvider.GetToken()
if err != nil {
return nil, fmt.Errorf("get authorization token: %v", err)
}
}

if strings.ContainsAny(authToken, "\r\n") {
return nil, fmt.Errorf("authorization token contains invalid newline sequence")
}
if len(authToken) != 0 {
req.HTTPRequest.Header.Set("Authorization", authToken)
}

Expand Down
146 changes: 95 additions & 51 deletions aws/credentials/endpointcreds/provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -159,64 +159,108 @@ func TestFailedRetrieveCredentials(t *testing.T) {
}

func TestAuthorizationToken(t *testing.T) {
const expectAuthToken = "Basic abc123"
cases := map[string]struct {
ExpectPath string
ServerPath string
AuthToken string
AuthTokenProvider endpointcreds.AuthTokenProvider
ExpectAuthToken string
ExpectError bool
}{
"AuthToken": {
ExpectPath: "/path/to/endpoint",
ServerPath: "/path/to/endpoint?something=else",
AuthToken: "Basic abc123",
ExpectAuthToken: "Basic abc123",
},
"AuthFileToken": {
ExpectPath: "/path/to/endpoint",
ServerPath: "/path/to/endpoint?something=else",
AuthToken: "Basic abc123",
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
return "Hello %20world", nil
}),
ExpectAuthToken: "Hello %20world",
},
"RetrieveFileTokenError": {
ExpectPath: "/path/to/endpoint",
ServerPath: "/path/to/endpoint?something=else",
AuthToken: "Basic abc123",
AuthTokenProvider: endpointcreds.TokenProviderFunc(func() (string, error) {
return "", fmt.Errorf("test error")
}),
ExpectAuthToken: "Hello %20world",
ExpectError: true,
},
}

server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := "/path/to/endpoint", r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := expectAuthToken, r.Header.Get("Authorization"); e != a {
t.Fatalf("expect %v, got %v", e, a)
}
for name, c := range cases {
t.Run(name, func(t *testing.T) {
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if e, a := c.ExpectPath, r.URL.Path; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "application/json", r.Header.Get("Accept"); e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := c.ExpectAuthToken, r.Header.Get("Authorization"); e != a {
t.Fatalf("expect %v, got %v", e, a)
}

encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})
encoder := json.NewEncoder(w)
err := encoder.Encode(map[string]interface{}{
"AccessKeyID": "AKID",
"SecretAccessKey": "SECRET",
"Token": "TOKEN",
"Expiration": time.Now().Add(1 * time.Hour),
})

if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
defer server.Close()
if err != nil {
fmt.Println("failed to write out creds", err)
}
}))
defer server.Close()

client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+"/path/to/endpoint?something=else",
func(p *endpointcreds.Provider) {
p.AuthorizationToken = expectAuthToken
},
)
creds, err := client.Retrieve()
client := endpointcreds.NewProviderClient(*unit.Session.Config,
unit.Session.Handlers,
server.URL+c.ServerPath,
func(p *endpointcreds.Provider) {
p.AuthorizationToken = c.AuthToken
p.AuthorizationTokenProvider = c.AuthTokenProvider
},
)
creds, err := client.Retrieve()

if err != nil {
t.Errorf("expect no error, got %v", err)
}
if err != nil && !c.ExpectError {
t.Errorf("expect no error, got %v", err)
} else if err == nil && c.ExpectError {
t.Errorf("expect error, got nil")
}

if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}
if c.ExpectError {
return
}

client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}
if e, a := "AKID", creds.AccessKeyID; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "SECRET", creds.SecretAccessKey; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if e, a := "TOKEN", creds.SessionToken; e != a {
t.Errorf("expect %v, got %v", e, a)
}
if client.IsExpired() {
t.Errorf("expect not expired, was")
}

if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
client.(*endpointcreds.Provider).CurrentTime = func() time.Time {
return time.Now().Add(2 * time.Hour)
}

if !client.IsExpired() {
t.Errorf("expect expired, wasn't")
}
})
}
}
64 changes: 54 additions & 10 deletions aws/defaults/defaults.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ package defaults

import (
"fmt"
"io/ioutil"
"net"
"net/http"
"net/url"
Expand Down Expand Up @@ -115,9 +116,31 @@ func CredProviders(cfg *aws.Config, handlers request.Handlers) []credentials.Pro

const (
httpProviderAuthorizationEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN"
httpProviderAuthFileEnvVar = "AWS_CONTAINER_AUTHORIZATION_TOKEN_FILE"
httpProviderEnvVar = "AWS_CONTAINER_CREDENTIALS_FULL_URI"
)

// direct representation of the IPv4 address for the ECS container
// "169.254.170.2"
var ecsContainerIPv4 net.IP = []byte{
169, 254, 170, 2,
}

// direct representation of the IPv4 address for the EKS container
// "169.254.170.23"
var eksContainerIPv4 net.IP = []byte{
169, 254, 170, 23,
}

// direct representation of the IPv6 address for the EKS container
// "fd00:ec2::23"
var eksContainerIPv6 net.IP = []byte{
0xFD, 0, 0xE, 0xC2,
0, 0, 0, 0,
0, 0, 0, 0,
0, 0, 0, 0x23,
}

// RemoteCredProvider returns a credentials provider for the default remote
// endpoints such as EC2 or ECS Roles.
func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.Provider {
Expand All @@ -135,26 +158,36 @@ func RemoteCredProvider(cfg aws.Config, handlers request.Handlers) credentials.P

var lookupHostFn = net.LookupHost

func isLoopbackHost(host string) (bool, error) {
ip := net.ParseIP(host)
if ip != nil {
return ip.IsLoopback(), nil
// isAllowedHost allows host to be loopback or known ECS/EKS container IPs
//
// host can either be an IP address OR an unresolved hostname - resolution will
// be automatically performed in the latter case
func isAllowedHost(host string) (bool, error) {
if ip := net.ParseIP(host); ip != nil {
return isIPAllowed(ip), nil
}

// Host is not an ip, perform lookup
addrs, err := lookupHostFn(host)
if err != nil {
return false, err
}

for _, addr := range addrs {
if !net.ParseIP(addr).IsLoopback() {
if ip := net.ParseIP(addr); ip == nil || !isIPAllowed(ip) {
return false, nil
}
}

return true, nil
}

func isIPAllowed(ip net.IP) bool {
return ip.IsLoopback() ||
ip.Equal(ecsContainerIPv4) ||
ip.Equal(eksContainerIPv4) ||
ip.Equal(eksContainerIPv6)
}

func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string) credentials.Provider {
var errMsg string

Expand All @@ -165,10 +198,12 @@ func localHTTPCredProvider(cfg aws.Config, handlers request.Handlers, u string)
host := aws.URLHostname(parsed)
if len(host) == 0 {
errMsg = "unable to parse host from local HTTP cred provider URL"
} else if isLoopback, loopbackErr := isLoopbackHost(host); loopbackErr != nil {
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, loopbackErr)
} else if !isLoopback {
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback hosts are allowed.", host)
} else if parsed.Scheme == "http" {
if isAllowedHost, allowHostErr := isAllowedHost(host); allowHostErr != nil {
errMsg = fmt.Sprintf("failed to resolve host %q, %v", host, allowHostErr)
} else if !isAllowedHost {
errMsg = fmt.Sprintf("invalid endpoint host, %q, only loopback/ecs/eks hosts are allowed.", host)
}
}
}

Expand All @@ -190,6 +225,15 @@ func httpCredProvider(cfg aws.Config, handlers request.Handlers, u string) crede
func(p *endpointcreds.Provider) {
p.ExpiryWindow = 5 * time.Minute
p.AuthorizationToken = os.Getenv(httpProviderAuthorizationEnvVar)
if authFilePath := os.Getenv(httpProviderAuthFileEnvVar); authFilePath != "" {
p.AuthorizationTokenProvider = endpointcreds.TokenProviderFunc(func() (string, error) {
if contents, err := ioutil.ReadFile(authFilePath); err != nil {
return "", fmt.Errorf("failed to read authorization token from %v: %v", authFilePath, err)
} else {
return string(contents), nil
}
})
}
},
)
}
Expand Down
Loading

0 comments on commit a07f333

Please sign in to comment.