diff --git a/changelog/fragments/1761377900-input-auth-method-aws.yaml b/changelog/fragments/1761377900-input-auth-method-aws.yaml new file mode 100644 index 000000000000..85503d2526f5 --- /dev/null +++ b/changelog/fragments/1761377900-input-auth-method-aws.yaml @@ -0,0 +1,45 @@ +# REQUIRED +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: feature + +# REQUIRED for all kinds +# Change summary; a 80ish characters long description of the change. +summary: Add AWS auth method for CEL and HTTP JSON inputs. + +# REQUIRED for breaking-change, deprecation, known-issue +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# description: + +# REQUIRED for breaking-change, deprecation, known-issue +# impact: + +# REQUIRED for breaking-change, deprecation, known-issue +# action: + +# REQUIRED for all kinds +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: filebeat + +# AUTOMATED +# OPTIONAL to manually add other PR URLs +# PR URL: A link the PR that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +# pr: https://github.com/owner/repo/1234 + +# AUTOMATED +# OPTIONAL to manually add other issue URLs +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +# issue: https://github.com/owner/repo/1234 diff --git a/x-pack/filebeat/input/cel/config_auth.go b/x-pack/filebeat/input/cel/config_auth.go index 1620b31bb03c..a141f5380e82 100644 --- a/x-pack/filebeat/input/cel/config_auth.go +++ b/x-pack/filebeat/input/cel/config_auth.go @@ -21,13 +21,15 @@ import ( "golang.org/x/oauth2/google" "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" ) type authConfig struct { - Basic *basicAuthConfig `config:"basic"` - Token *tokenAuthConfig `config:"token"` - Digest *digestAuthConfig `config:"digest"` - OAuth2 *oAuth2Config `config:"oauth2"` + Basic *basicAuthConfig `config:"basic"` + Token *tokenAuthConfig `config:"token"` + Digest *digestAuthConfig `config:"digest"` + OAuth2 *oAuth2Config `config:"oauth2"` + AWS *aws.SignerInputConfig `config:"aws"` } func (c authConfig) Validate() error { @@ -44,6 +46,9 @@ func (c authConfig) Validate() error { if c.OAuth2.isEnabled() { n++ } + if c.AWS.IsEnabled() { + n++ + } if n > 1 { return errors.New("only one kind of auth can be enabled") } diff --git a/x-pack/filebeat/input/cel/input.go b/x-pack/filebeat/input/cel/input.go index 6c2346a86fac..76f7251e6530 100644 --- a/x-pack/filebeat/input/cel/input.go +++ b/x-pack/filebeat/input/cel/input.go @@ -51,6 +51,7 @@ import ( "github.com/elastic/beats/v7/libbeat/version" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httplog" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httpmon" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/mapstr" "github.com/elastic/elastic-agent-libs/monitoring" @@ -213,6 +214,7 @@ func (i input) run(env v2.Context, src *source, cursor map[string]interface{}, p Value: cfg.Auth.Token.Value, } } + wantDump := cfg.FailureDump.enabled() && cfg.FailureDump.Filename != "" doCov := cfg.RecordCoverage && log.IsDebug() httpOptions := lib.HTTPOptions{ @@ -864,6 +866,15 @@ func newClient(ctx context.Context, cfg config, log *logp.Logger, reg *monitorin Password: cfg.Auth.Digest.Password, NoReuse: noReuse, } + } else if cfg.Auth.AWS.IsEnabled() { + // this transport runs after the other ones (the other ones wrap this one); just to be on the safe side. + // If any of the other transports add any header, it must happen before the signing. + tr, err := aws.InitializeSignerTransport(*cfg.Auth.AWS, log, c.Transport) + if err != nil { + log.Errorw("failed to initialize aws config failed for signer", "error", err) + return nil, nil, err + } + c.Transport = tr } var trace *httplog.LoggingRoundTripper diff --git a/x-pack/filebeat/input/cel/input_test.go b/x-pack/filebeat/input/cel/input_test.go index afb7c1418b22..f91eaff52081 100644 --- a/x-pack/filebeat/input/cel/input_test.go +++ b/x-pack/filebeat/input/cel/input_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "sync" "testing" "time" @@ -1890,6 +1891,44 @@ var inputTests = []struct { }, }, + { + name: "Auth AWS V4 Signer", + server: func(t *testing.T, h http.HandlerFunc, config map[string]interface{}) { + s := httptest.NewServer(h) + config["resource.url"] = s.URL + t.Cleanup(s.Close) + }, + config: map[string]interface{}{ + "interval": 1, + "auth.aws.access_key_id": "AKIAIOSFODNN7EXAMPLE", + "auth.aws.secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "auth.aws.default_region": "us-east-1", + "auth.aws.service_name": "guardduty", + "program": ` + bytes(get(state.url).Body).as(body, { + "events": [body.decode_json()] + }) + `, + }, + handler: awsAuthHandler("AKIAIOSFODNN7EXAMPLE", defaultHandler(http.MethodGet, "")), + want: []map[string]interface{}{ + { + "hello": []interface{}{ + map[string]interface{}{ + "world": "moon", + }, + map[string]interface{}{ + "space": []interface{}{ + map[string]interface{}{ + "cake": "pumpkin", + }, + }, + }, + }, + }, + }, + }, + // Multi-step requests. { name: "simple_multistep_GET_request", @@ -2495,6 +2534,24 @@ func tokenAuthHandler(want string, handle http.HandlerFunc) http.HandlerFunc { } } +func awsAuthHandler(expectedTokenID string, handle http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/", expectedTokenID)) { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + amzDate := r.Header.Get("X-Amz-Date") + if amzDate == "" { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + handle(w, r) + } +} + //nolint:errcheck // No point checking errors in test server. func digestAuthHandler(user, pass, realm, nonce string, handle http.HandlerFunc) http.HandlerFunc { chal := &digest.Challenge{ diff --git a/x-pack/filebeat/input/httpjson/config_auth.go b/x-pack/filebeat/input/httpjson/config_auth.go index 4021d0c29e88..0b68f4399403 100644 --- a/x-pack/filebeat/input/httpjson/config_auth.go +++ b/x-pack/filebeat/input/httpjson/config_auth.go @@ -20,15 +20,27 @@ import ( "golang.org/x/oauth2/google" "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" ) type authConfig struct { - Basic *basicAuthConfig `config:"basic"` - OAuth2 *oAuth2Config `config:"oauth2"` + Basic *basicAuthConfig `config:"basic"` + OAuth2 *oAuth2Config `config:"oauth2"` + AWS *aws.SignerInputConfig `config:"aws"` } func (c authConfig) Validate() error { - if c.Basic.isEnabled() && c.OAuth2.isEnabled() { + var n int + if c.Basic.isEnabled() { + n++ + } + if c.OAuth2.isEnabled() { + n++ + } + if c.AWS.IsEnabled() { + n++ + } + if n > 1 { return errors.New("only one kind of auth can be enabled") } return nil diff --git a/x-pack/filebeat/input/httpjson/input.go b/x-pack/filebeat/input/httpjson/input.go index 7809421ad1e6..cf7bd7ccf306 100644 --- a/x-pack/filebeat/input/httpjson/input.go +++ b/x-pack/filebeat/input/httpjson/input.go @@ -38,6 +38,7 @@ import ( "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httplog" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httpmon" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/private" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/mapstr" "github.com/elastic/elastic-agent-libs/monitoring" @@ -304,7 +305,22 @@ func newHTTPClient(ctx context.Context, authCfg *authConfig, requestCfg *request client *http.Client err error ) - if authCfg.OAuth2.isEnabled() { + switch { + case authCfg.AWS.IsEnabled(): + client, err = newNetHTTPClient(ctx, requestCfg, log, reg) + if err != nil { + log.Errorw("creation of initial http client failed", "error", err) + return nil, err + } + + log.Debugw("creating signer", "region", authCfg.AWS.DefaultRegion, "service", authCfg.AWS.ServiceName) + tr, err := aws.InitializeSignerTransport(*authCfg.AWS, log, client.Transport) + if err != nil { + log.Errorw("failed to initialize aws config failed for signer", "error", err) + return nil, err + } + client.Transport = tr + case authCfg.OAuth2.isEnabled(): client = authCfg.OAuth2.prepared if client == nil { client, err = newNetHTTPClient(ctx, requestCfg, log, reg) @@ -317,7 +333,7 @@ func newHTTPClient(ctx context.Context, authCfg *authConfig, requestCfg *request } authCfg.OAuth2.prepared = client } - } else { + default: client, err = newNetHTTPClient(ctx, requestCfg, log, reg) if err != nil { return nil, err diff --git a/x-pack/filebeat/input/httpjson/input_test.go b/x-pack/filebeat/input/httpjson/input_test.go index 7a18b704d36a..69bfe67ad0ff 100644 --- a/x-pack/filebeat/input/httpjson/input_test.go +++ b/x-pack/filebeat/input/httpjson/input_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" @@ -653,6 +654,24 @@ var testCases = []struct { handler: oauth2Handler, expected: []string{`{"hello": "world"}`}, }, + { + name: "aws auth", + setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { + server := httptest.NewServer(h) + config["request.url"] = server.URL + t.Cleanup(server.Close) + }, + baseConfig: map[string]interface{}{ + "interval": 1, + "request.method": http.MethodGet, + "auth.aws.access_key_id": "AKIAIOSFODNN7EXAMPLE", + "auth.aws.secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "auth.aws.default_region": "us-east-1", + "auth.aws.service_name": "guardduty", + }, + handler: awsAuthHandler("AKIAIOSFODNN7EXAMPLE", defaultHandler(http.MethodGet, "", "")), + expected: []string{`{"hello":[{"world":"moon"},{"space":[{"cake":"pumpkin"}]}]}`}, + }, { name: "request_transforms_can_access_state_from_previous_transforms", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { @@ -1889,6 +1908,24 @@ func defaultHandler(expectedMethod, expectedBody, msg string) http.HandlerFunc { } } +func awsAuthHandler(expectedTokenID string, handle http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/", expectedTokenID)) { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + amzDate := r.Header.Get("X-Amz-Date") + if amzDate == "" { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + handle(w, r) + } +} + func rateLimitHandler() http.HandlerFunc { var isRetry bool return func(w http.ResponseWriter, r *http.Request) { diff --git a/x-pack/libbeat/common/aws/signer.go b/x-pack/libbeat/common/aws/signer.go new file mode 100644 index 000000000000..4f8bb79e1efe --- /dev/null +++ b/x-pack/libbeat/common/aws/signer.go @@ -0,0 +1,207 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package aws + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awslogging "github.com/aws/smithy-go/logging" + + "github.com/elastic/elastic-agent-libs/logp" +) + +// SignerInputConfig is the top-level configuration for the input aws auth method, +// used both in CEL and HTTP JSON inputs. It wraps the [ConfigAWS] inlined. +type SignerInputConfig struct { + // Enabled indicates whether this auth method is used. + // [SignerInputConfig.IsEnabled] implements the logic of the check. + Enabled *bool `config:"enabled"` + + // ServiceName can be used to optionally set the AWS service name that AWS V4 signer will sign for. + // If not value is set here, the signer will try to infer the service from the url of the request. + ServiceName string `config:"service_name"` + + // ConfigAWS is the inline wrapping of the rest of the [ConfigAWS] that can be use in the input config. + ConfigAWS `config:",inline"` +} + +// IsEnabled returns true if the `enable` field is set to true in the yaml or if it is nil. +func (c *SignerInputConfig) IsEnabled() bool { + return c != nil && (c.Enabled == nil || *c.Enabled) +} + +// SignerTransport implements [http.RoundTripper] interface +// and signs requests with aws v4 signer before send them to the next roundtripper. +// If the `serviceName` and `region` are not set, the signer will try to infer them from each request's URL. +type SignerTransport struct { + next http.RoundTripper + credentials aws.CredentialsProvider + signer *v4.Signer + logger *logp.Logger + serviceName string + region string + now func() time.Time // we don't use [time.Now] directly, so we can mock time in tests. +} + +// InitializeSignerTransport initializes first the AWS config using the [InitializeAWSConfig] and the [ConfigAWS] and then, +// initializes the RoundTripper using the AWS credentials from the previously initialized AWS config. +func InitializeSignerTransport(cfg SignerInputConfig, logger *logp.Logger, nextTransport http.RoundTripper) (*SignerTransport, error) { + awsConfig, err := InitializeAWSConfig(cfg.ConfigAWS, logger) + if err != nil { + return nil, err + } + + return initializeSignerTransport(logger, cfg.ServiceName, cfg.DefaultRegion, awsConfig.Credentials, nextTransport), nil +} + +func initializeSignerTransport(logger *logp.Logger, defaultServiceName string, defaultRegion string, credentials aws.CredentialsProvider, nextTransport http.RoundTripper) *SignerTransport { + return &SignerTransport{ + next: nextTransport, + credentials: credentials, + signer: v4.NewSigner(func(signer *v4.SignerOptions) { + signer.Logger = awslogging.LoggerFunc(func(classification awslogging.Classification, format string, v ...any) { + switch classification { + case awslogging.Debug: + logger.Debugf(format, v...) + case awslogging.Warn: + logger.Warnf(format, v...) + } + }) + }), + logger: logger, + serviceName: defaultRegion, + region: defaultServiceName, + now: time.Now, + } +} + +func (st *SignerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // resolve service name and region (if they are not configured) + serviceName, region, err := st.getServiceAndRegion(req) + if err != nil { + return nil, fmt.Errorf("error while getting service name and region: %w", err) + } + + // retrieve credentials + creds, err := st.credentials.Retrieve(req.Context()) + if err != nil { + return nil, fmt.Errorf("error while retrieving credentials: %w", err) + } + + // body hash + payloadHash, err := st.bodySHA256Hash(req) + if err != nil { + return nil, fmt.Errorf("error while calculating body hash: %w", err) + } + + // sign the request + err = st.signer.SignHTTP(req.Context(), creds, req, payloadHash, serviceName, region, st.now()) + if err != nil { + return nil, fmt.Errorf("error while signing the request: %w", err) + } + + // next transport + return st.next.RoundTrip(req) +} + +// bodySHA256Hash returns the sha256 hash of the request's body by reading a copy of the body. +// The request's Body remains readable and unmodified after this function returns. +func (st *SignerTransport) bodySHA256Hash(req *http.Request) (string, error) { + if req.Body == nil || req.Body == http.NoBody { + return hex.EncodeToString(sha256.New().Sum(nil)), nil + } + + // this is a copy of the original body + body, err := st.getBody(req) + if err != nil { + return "", err + } + + hash := sha256.New() + + if _, err := io.Copy(hash, body); err != nil { + return "", err + } + + return hex.EncodeToString(hash.Sum(nil)), body.Close() +} + +// getBody returns a copy of the request's body as a [io.ReadCloser]. +// The request's Body remains readable and unmodified after this function returns. +func (st *SignerTransport) getBody(req *http.Request) (io.ReadCloser, error) { + if req.GetBody != nil { + // [http.Request] GetBody dictates that a new copy of the body must be returned. + return req.GetBody() + } + + if req.Body == http.NoBody || req.Body == nil { + return req.Body, nil + } + + // If the GetBody does not exist we need to manually copy the body. + // In Beats use-case its not possible for this to happen, + // since, both in cel and httpjson, the request is initialized + // with *bytes.Buffer, *bytes.Reader or *strings.Reader as body, which gets GetBody initialized. + // httpjson: (x-pack/filebeat/input/httpjson/request.go newHTTPRequest) + // cel: mito repo (lib/http.go) + // We cover the edge case here by reading and copying the body manually. + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("error while reading request body: %w", err) + } + if err := req.Body.Close(); err != nil { + st.logger.Warnf("error while closing copied body %s", err.Error()) + } + + // reset body to the request + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + return io.NopCloser(bytes.NewBuffer(bodyBytes)), nil +} + +// getServiceAndRegion returns the service name and the region for the upcoming request. +// If service name and region are configured with default values, those take precedence. +// Otherwise it will try to parse the values from [http.Request] Host value. +func (st *SignerTransport) getServiceAndRegion(req *http.Request) (serviceName, region string, err error) { + serviceName = st.serviceName + region = st.region + + if serviceName == "" || region == "" { + s, r, err := parseServiceAndRegionFromHost(req.Host) + if err != nil { + return "", "", err + } + if serviceName == "" { + serviceName = s + } + if region == "" { + region = r + } + } + + return serviceName, region, nil +} + +func parseServiceAndRegionFromHost(host string) (service, region string, err error) { + parts := strings.SplitN(host, ".", 4) + + if len(parts) < 4 { + return "", "", errMalformedHost + } + + return parts[0], parts[1], nil +} + +var errMalformedHost = errors.New("malformed host string") diff --git a/x-pack/libbeat/common/aws/signer_test.go b/x-pack/libbeat/common/aws/signer_test.go new file mode 100644 index 000000000000..405e630c1834 --- /dev/null +++ b/x-pack/libbeat/common/aws/signer_test.go @@ -0,0 +1,296 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package aws + +import ( + "bytes" + "io" + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/logp/logptest" +) + +func mockNow(v time.Time) func() time.Time { return func() time.Time { return v } } + +type mockRoundTripper struct { + mock.Mock + req *http.Request +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Called(req) + m.req = req // store the request for later assertions. + return args.Get(0).(*http.Response), args.Error(1) //nolint:errcheck // not needed here. +} + +func TestSignerTransportRoundTrip(t *testing.T) { + now := mockNow(time.Date(2025, time.October, 11, 16, 0, 0, 0, time.UTC)) + + // fake credentials received from this: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html + fakeStaticCreds := credentials.NewStaticCredentialsProvider("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "session") + + tests := []struct { + name string + defaultServiceName string + defaultRegion string + url string + requestBody io.Reader + requestHeaders map[string]string + credentials aws.CredentialsProvider + now func() time.Time + initMockRoundTripper func(*mockRoundTripper) + expectError bool + expectedRequestHeaders map[string]string + expectedRequestBody []byte + }{ + { + name: "no body", + defaultServiceName: "", + defaultRegion: "", + url: "https://guardduty.us-east-1.amazonaws.com/detector/abc123/findings", + requestBody: http.NoBody, + requestHeaders: map[string]string{}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=a73ff41e90b3e54c8855dc53cb352c244f4cf39122838e4ded22eef0fde01095", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + }, + expectedRequestBody: []byte{}, + }, + { + name: "with body", + defaultServiceName: "", + defaultRegion: "", + url: "https://guardduty.us-east-1.amazonaws.com/detector/abc123/findings", + requestBody: bytes.NewBuffer([]byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`)), + requestHeaders: map[string]string{}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=content-length;host;x-amz-date;x-amz-security-token, Signature=1cba6843418733071843e982a5e399eebfa3caeef3bae336ab4477abf42a9fb7", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + }, + expectedRequestBody: []byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`), + }, + { + name: "with body and headers", + defaultServiceName: "", + defaultRegion: "", + url: "https://guardduty.us-east-1.amazonaws.com/detector/abc123/findings", + requestBody: bytes.NewBuffer([]byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`)), + requestHeaders: map[string]string{"X-Extra-Header": "abc123"}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=content-length;host;x-amz-date;x-amz-security-token;x-extra-header, Signature=a9ae9766395c5749fca156baf9c65ef78d4d3053866299db48838c3546aaeb25", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + "X-Extra-Header": "abc123", + }, + expectedRequestBody: []byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := logptest.NewTestingLogger(t, "") + + m := mockRoundTripper{} + if tc.initMockRoundTripper != nil { + tc.initMockRoundTripper(&m) + } + + st := initializeSignerTransport(logger, tc.defaultServiceName, tc.defaultRegion, tc.credentials, &m) + st.now = tc.now + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, tc.url, tc.requestBody) + require.NoError(t, err) + for k, v := range tc.requestHeaders { + req.Header.Set(k, v) + } + + _, err = st.RoundTrip(req) //nolint:bodyclose // we don't actually have response body here + errAssert := assert.NoError + if tc.expectError { + errAssert = assert.Error + } + errAssert(t, err) + + gotHeaders := map[string]string{} + for k := range m.req.Header { + gotHeaders[k] = m.req.Header.Get(k) + } + assert.EqualValues(t, tc.expectedRequestHeaders, gotHeaders) + + // ensure that request's body is readable (and not consumed) after the hash operation. + b, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, tc.expectedRequestBody, b) + }) + } +} + +func TestBodySHA256Hash(t *testing.T) { + tests := []struct { + name string + body io.Reader + expectedHash string + expectError bool + }{ + { + name: "no body", + body: http.NoBody, + expectedHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectError: false, + }, + { + name: "with body that initializes GetBody", + body: bytes.NewReader([]byte(`"abc"`)), + expectedHash: "6cc43f858fbb763301637b5af970e2a46b46f461f27e5a0f41e009c59b827b25", + expectError: false, + }, + { + name: "with body without initialized GetBody", + body: io.NopCloser(bytes.NewReader([]byte(`"abc"`))), + expectedHash: "6cc43f858fbb763301637b5af970e2a46b46f461f27e5a0f41e009c59b827b25", + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := logptest.NewTestingLogger(t, "") + st := SignerTransport{ + next: nil, + credentials: nil, + signer: nil, + logger: logger, + serviceName: "", + region: "", + now: time.Now, + } + + req, err := http.NewRequestWithContext(t.Context(), "GET", "sample.amazonaws.com", tc.body) + require.NoError(t, err) + + gotHash, gotErr := st.bodySHA256Hash(req) + + assert.Equal(t, tc.expectedHash, gotHash, "hash of body is different than expected") + errAssert := assert.NoError + if tc.expectError { + errAssert = assert.Error + } + errAssert(t, gotErr) + }) + } +} + +func TestGetServiceAndRegion(t *testing.T) { + tests := []struct { + name string + configuredServiceName string + configuredRegion string + requestHost string + expectedServiceName string + expectedRegion string + expectError bool + }{ + { + name: "extract from host", + configuredServiceName: "", + configuredRegion: "", + requestHost: "guardduty.us-east-1.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "configured values take precedence", + configuredServiceName: "guardduty", + configuredRegion: "us-east-1", + requestHost: "abc.us-east-2.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "service name configured region from url", + configuredServiceName: "guardduty", + configuredRegion: "", + requestHost: "abc.us-east-2.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-2", + expectError: false, + }, + { + name: "service name from url region configured", + configuredServiceName: "", + configuredRegion: "us-east-1", + requestHost: "guardduty.us-east-2.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "malformed host", + configuredServiceName: "", + configuredRegion: "", + requestHost: "amazonaws.com", + expectedServiceName: "", + expectedRegion: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := logptest.NewTestingLogger(t, "") + st := SignerTransport{ + next: nil, + credentials: nil, + signer: nil, + logger: logger, + serviceName: tc.configuredServiceName, + region: tc.configuredRegion, + now: time.Now, + } + + req := &http.Request{Host: tc.requestHost} + + gotServiceName, gotRegion, gotErr := st.getServiceAndRegion(req) + + assert.Equal(t, tc.expectedServiceName, gotServiceName, "service name is different than expected") + assert.Equal(t, tc.expectedRegion, gotRegion, "service name is different than expected") + errAssert := assert.NoError + if tc.expectError { + errAssert = assert.Error + } + errAssert(t, gotErr) + }) + } +}