Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions .github/ISSUE_TEMPLATE/testplan.md
Original file line number Diff line number Diff line change
Expand Up @@ -1041,9 +1041,10 @@ tsh bench web sessions --max=5000 --web user ls
- [ ] Verify [AWS console access](https://goteleport.com/docs/application-access/cloud-apis/aws-console/).
- [ ] Can log into AWS web console through the web UI.
- [ ] Can interact with AWS using `tsh` commands.
- [ ] `tsh aws`
- [ ] `tsh aws --endpoint-url` (this is a hidden flag)
- [ ] Verify [Azure CLI access](https://goteleport.com/docs/application-access/cloud-apis/azure/) with `tsh apps login`.
- [ ] `tsh aws sts get-caller-identity`
- [ ] `tsh aws s3 ls`
- [ ] `tsh aws s3 cp ./file s3://<bucket>/test`
- [ ] Verify [Azure CLI access](https://goteleport.com/docs/enroll-resources/application-access/cloud-apis/azure/) with `tsh apps login`.
- [ ] Can interact with Azure using `tsh az` commands.
- [ ] Can interact with Azure using a combination of `tsh proxy az` and `az` commands.
- [ ] Verify [GCP CLI access](https://goteleport.com/docs/application-access/cloud-apis/google-cloud/) with `tsh apps login`.
Expand Down
4 changes: 3 additions & 1 deletion lib/srv/alpnproxy/aws_local_proxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,7 +76,9 @@ func TestAWSAccessMiddleware(t *testing.T) {
t.Run("request with body", func(t *testing.T) {
body := []byte("body")
req := httptest.NewRequest(http.MethodPost, "http://sts.us-east-2.amazonaws.com", bytes.NewReader(body))
awsutils.NewSigner("sts").SignHTTP(t.Context(), localCred, req, awsutils.GetV4PayloadHash(body), "sts", "us-east-2", time.Now())
payloadHash, err := awsutils.GetV4PayloadHash(req)
require.NoError(t, err)
awsutils.NewSigner("sts").SignHTTP(t.Context(), localCred, req, payloadHash, "sts", "us-east-2", time.Now())

recorder := httptest.NewRecorder()
require.False(t, m.HandleRequest(recorder, req))
Expand Down
24 changes: 17 additions & 7 deletions lib/utils/aws/aws.go
Original file line number Diff line number Diff line change
Expand Up @@ -192,9 +192,7 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider)
}
}

// Read the request body and replace the body ready with a new reader that will allow reading the body again
// by HTTP Transport.
payload, err := utils.GetAndReplaceRequestBody(req)
payloadHash, err := GetV4PayloadHash(req)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -229,7 +227,7 @@ func VerifyAWSSignature(req *http.Request, credProvider aws.CredentialsProvider)
}

signer := NewSigner(sigV4.Service)
err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), sigV4.Service, sigV4.Region, t)
err = signer.SignHTTP(ctx, creds, reqCopy, payloadHash, sigV4.Service, sigV4.Region, t)
if err != nil {
return trace.Wrap(err)
}
Expand Down Expand Up @@ -261,14 +259,26 @@ func NewSigner(signingServiceName string) *v4.Signer {
}

// GetV4PayloadHash returns the V4 signing payload hash.
func GetV4PayloadHash(payload []byte) string {
func GetV4PayloadHash(req *http.Request) (string, error) {
payloadHash := strings.ToUpper(req.Header.Get("x-amz-content-sha256"))
switch payloadHash {
// unsigned payload, so we use the literal content string instead of hashing
// https://docs.aws.amazon.com/AmazonS3/latest/API/sigv4-auth-using-authorization-header.html
case "UNSIGNED-PAYLOAD", "STREAMING-UNSIGNED-PAYLOAD-TRAILER":
return payloadHash, nil
default:
}
payload, err := utils.GetAndReplaceRequestBody(req)
if err != nil {
return "", trace.Wrap(err)
}
if len(payload) == 0 {
return EmptyPayloadHash
return EmptyPayloadHash, nil
}

hash := sha256.New()
hash.Write(payload)
return hex.EncodeToString(hash.Sum(nil))
return hex.EncodeToString(hash.Sum(nil)), nil
}

// filterHeaders removes request headers that are not in the headers list and returns the removed header keys.
Expand Down
110 changes: 110 additions & 0 deletions lib/utils/aws/aws_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,9 +19,21 @@
package aws

import (
"crypto/tls"
"io"
"net/http"
"net/http/httptest"
"strings"
"testing"

"github.com/aws/aws-sdk-go-v2/aws"
"github.com/aws/aws-sdk-go-v2/credentials"
"github.com/aws/aws-sdk-go-v2/service/s3"
s3types "github.com/aws/aws-sdk-go-v2/service/s3/types"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/lib/utils"
)

// TestExtractCredFromAuthHeader test the extractCredFromAuthHeader function logic.
Expand Down Expand Up @@ -516,3 +528,101 @@ func TestResourceARN(t *testing.T) {
})
}
}

func TestVerifyAWSSignature(t *testing.T) {
creds1 := credentials.NewStaticCredentialsProvider("sameid", "secret1", "")
creds2 := credentials.NewStaticCredentialsProvider("sameid", "secret2", "")
creds3 := credentials.NewStaticCredentialsProvider("otherid", "secret1", "")
tests := []struct {
desc string
clientCreds aws.CredentialsProvider
serverCreds aws.CredentialsProvider
checksumAlgorithm s3types.ChecksumAlgorithm
noTLS bool
wantSha256Header string
wantError string
}{
{
desc: "unsigned payload",
clientCreds: creds1,
serverCreds: creds1,
wantSha256Header: "UNSIGNED-PAYLOAD",
},
{
desc: "signed payload",
clientCreds: creds1,
serverCreds: creds1,
// echo -n 'llama' | sha256sum
wantSha256Header: "fc5a1047f5919892fcdf8aa79ea5d6bb6531b5c176939ef0110906cb225941c1",
noTLS: true,
},
{
desc: "streaming unsigned payload trailer",
clientCreds: creds1,
serverCreds: creds1,
checksumAlgorithm: s3types.ChecksumAlgorithmCrc32,
wantSha256Header: "STREAMING-UNSIGNED-PAYLOAD-TRAILER",
},
{
desc: "different credential ID",
clientCreds: creds3,
serverCreds: creds1,
wantSha256Header: "UNSIGNED-PAYLOAD",
wantError: "AccessKeyID does not match",
},
{
desc: "different credential secret",
clientCreds: creds2,
serverCreds: creds1,
wantSha256Header: "UNSIGNED-PAYLOAD",
wantError: "signature verification failed",
},
}
for _, test := range tests {
t.Run(test.desc, func(t *testing.T) {
srv := httptest.NewUnstartedServer(http.HandlerFunc(func(w http.ResponseWriter, req *http.Request) {
assert.Equal(t, test.wantSha256Header, req.Header.Get("x-amz-content-sha256"))
bodyBefore, err := utils.GetAndReplaceRequestBody(req)
assert.NoError(t, err)
err = VerifyAWSSignature(req, test.serverCreds)
if test.wantError != "" {
assert.ErrorContains(t, err, test.wantError)
} else {
assert.NoError(t, err)
}
bodyAfter, err := io.ReadAll(req.Body)
assert.NoError(t, err)
assert.Equal(t, string(bodyBefore), string(bodyAfter),
"checking a signature should not modify the request contents",
)
w.WriteHeader(http.StatusOK)
}))
defer srv.Close()
if test.noTLS {
srv.Start()
} else {
srv.StartTLS()
}
tr := http.DefaultTransport.(*http.Transport).Clone()
tr.TLSClientConfig = &tls.Config{
InsecureSkipVerify: true,
}
clt := s3.New(s3.Options{
Credentials: test.clientCreds,
BaseEndpoint: aws.String(srv.URL),
Region: "us-west-2",
RetryMaxAttempts: 1,
HTTPClient: &http.Client{
Transport: tr,
},
})
_, err := clt.PutObject(t.Context(), &s3.PutObjectInput{
ChecksumAlgorithm: test.checksumAlgorithm,
Bucket: aws.String("bucket"),
Key: aws.String("key"),
Body: strings.NewReader("llama"),
})
require.NoError(t, err)
})
}
}
4 changes: 2 additions & 2 deletions lib/utils/aws/signing.go
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ func SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (*
if err := signCtx.Check(); err != nil {
return nil, trace.Wrap(err)
}
payload, err := utils.GetAndReplaceRequestBody(req)
payloadHash, err := GetV4PayloadHash(req)
if err != nil {
return nil, trace.Wrap(err)
}
Expand All @@ -101,7 +101,7 @@ func SignRequest(ctx context.Context, req *http.Request, signCtx *SigningCtx) (*
}

signer := NewSigner(signCtx.SigningName)
err = signer.SignHTTP(ctx, creds, reqCopy, GetV4PayloadHash(payload), signCtx.SigningName, signCtx.SigningRegion, time.Now())
err = signer.SignHTTP(ctx, creds, reqCopy, payloadHash, signCtx.SigningName, signCtx.SigningRegion, time.Now())
if err != nil {
return nil, trace.Wrap(err)
}
Expand Down
Loading