diff --git a/changelog/fragments/1761377900-input-auth-method-aws.yaml b/changelog/fragments/1761377900-input-auth-method-aws.yaml new file mode 100644 index 000000000000..85503d2526f5 --- /dev/null +++ b/changelog/fragments/1761377900-input-auth-method-aws.yaml @@ -0,0 +1,45 @@ +# REQUIRED +# Kind can be one of: +# - breaking-change: a change to previously-documented behavior +# - deprecation: functionality that is being removed in a later release +# - bug-fix: fixes a problem in a previous version +# - enhancement: extends functionality but does not break or fix existing behavior +# - feature: new functionality +# - known-issue: problems that we are aware of in a given version +# - security: impacts on the security of a product or a user’s deployment. +# - upgrade: important information for someone upgrading from a prior version +# - other: does not fit into any of the other categories +kind: feature + +# REQUIRED for all kinds +# Change summary; a 80ish characters long description of the change. +summary: Add AWS auth method for CEL and HTTP JSON inputs. + +# REQUIRED for breaking-change, deprecation, known-issue +# Long description; in case the summary is not enough to describe the change +# this field accommodate a description without length limits. +# description: + +# REQUIRED for breaking-change, deprecation, known-issue +# impact: + +# REQUIRED for breaking-change, deprecation, known-issue +# action: + +# REQUIRED for all kinds +# Affected component; usually one of "elastic-agent", "fleet-server", "filebeat", "metricbeat", "auditbeat", "all", etc. +component: filebeat + +# AUTOMATED +# OPTIONAL to manually add other PR URLs +# PR URL: A link the PR that added the changeset. +# If not present is automatically filled by the tooling finding the PR where this changelog fragment has been added. +# NOTE: the tooling supports backports, so it's able to fill the original PR number instead of the backport PR number. +# Please provide it if you are adding a fragment for a different PR. +# pr: https://github.com/owner/repo/1234 + +# AUTOMATED +# OPTIONAL to manually add other issue URLs +# Issue URL; optional; the GitHub issue related to this changeset (either closes or is part of). +# If not present is automatically filled by the tooling with the issue linked to the PR number. +# issue: https://github.com/owner/repo/1234 diff --git a/x-pack/filebeat/input/cel/config_auth.go b/x-pack/filebeat/input/cel/config_auth.go index 1620b31bb03c..57f568f729d1 100644 --- a/x-pack/filebeat/input/cel/config_auth.go +++ b/x-pack/filebeat/input/cel/config_auth.go @@ -21,13 +21,15 @@ import ( "golang.org/x/oauth2/google" "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" ) type authConfig struct { - Basic *basicAuthConfig `config:"basic"` - Token *tokenAuthConfig `config:"token"` - Digest *digestAuthConfig `config:"digest"` - OAuth2 *oAuth2Config `config:"oauth2"` + Basic *basicAuthConfig `config:"basic"` + Token *tokenAuthConfig `config:"token"` + Digest *digestAuthConfig `config:"digest"` + OAuth2 *oAuth2Config `config:"oauth2"` + AWS *aws.SignerInputConfig `config:"aws"` } func (c authConfig) Validate() error { @@ -44,6 +46,9 @@ func (c authConfig) Validate() error { if c.OAuth2.isEnabled() { n++ } + if c.AWS.IsEnabled() { + n++ + } if n > 1 { return errors.New("only one kind of auth can be enabled") } @@ -229,7 +234,7 @@ func (o *oAuth2Config) client(ctx context.Context, client *http.Client) (*http.C var creds *google.Credentials var err error if len(o.GoogleCredentialsJSON) != 0 { - creds, err = google.CredentialsFromJSON(ctx, o.GoogleCredentialsJSON, o.Scopes...) + creds, err = google.CredentialsFromJSON(ctx, o.GoogleCredentialsJSON, o.Scopes...) //nolint:staticcheck // deprecated but no suitable replacement available if err != nil { return nil, fmt.Errorf("oauth2 client: error loading credentials: %w", err) } diff --git a/x-pack/filebeat/input/cel/input.go b/x-pack/filebeat/input/cel/input.go index 6c2346a86fac..8264f8cfd7a3 100644 --- a/x-pack/filebeat/input/cel/input.go +++ b/x-pack/filebeat/input/cel/input.go @@ -51,6 +51,7 @@ import ( "github.com/elastic/beats/v7/libbeat/version" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httplog" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httpmon" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/mapstr" "github.com/elastic/elastic-agent-libs/monitoring" @@ -100,7 +101,7 @@ func (i input) now() time.Time { func (input) Name() string { return inputName } func (input) Test(src inputcursor.Source, _ v2.TestContext) error { - cfg := src.(*source).cfg + cfg := src.(*source).cfg //nolint:errcheck // src is always *source in this input implementation if !wantClient(cfg) { return nil } @@ -110,7 +111,7 @@ func (input) Test(src inputcursor.Source, _ v2.TestContext) error { // Run starts the input and blocks until it ends completes. It will return on // context cancellation or type invalidity errors, any other error will be retried. func (input) Run(env v2.Context, src inputcursor.Source, crsr inputcursor.Cursor, pub inputcursor.Publisher) error { - dataStreamName := src.(*source).cfg.DataStream // May be empty. + dataStreamName := src.(*source).cfg.DataStream //nolint:errcheck // src is always *source in this input implementation var cursor map[string]interface{} env.UpdateStatus(status.Starting, dataStreamName) @@ -128,7 +129,7 @@ func (input) Run(env v2.Context, src inputcursor.Source, crsr inputcursor.Cursor parent: &env, } } - err := input{}.run(env, src.(*source), cursor, pub, health) + err := input{}.run(env, src.(*source), cursor, pub, health) //nolint:errcheck // src is always *source in this input implementation if err != nil { msg := "failed to run: " + err.Error() if dataStreamName != "" { @@ -182,7 +183,7 @@ func (i input) run(env v2.Context, src *source, cursor map[string]interface{}, p return err } if !ok { - return fmt.Errorf("request tracer path %q must be within %q path", path, paths.Resolve(paths.Logs, inputName)) + return fmt.Errorf("request tracer path %q must be within %q path", path, paths.Resolve(paths.Logs, inputName)) //nolint:forbidigo // no per-beat path instance available here } cfg.Resource.Tracer.Filename = resolved } @@ -213,6 +214,7 @@ func (i input) run(env v2.Context, src *source, cursor map[string]interface{}, p Value: cfg.Auth.Token.Value, } } + wantDump := cfg.FailureDump.enabled() && cfg.FailureDump.Filename != "" doCov := cfg.RecordCoverage && log.IsDebug() httpOptions := lib.HTTPOptions{ @@ -864,6 +866,15 @@ func newClient(ctx context.Context, cfg config, log *logp.Logger, reg *monitorin Password: cfg.Auth.Digest.Password, NoReuse: noReuse, } + } else if cfg.Auth.AWS.IsEnabled() { + // this transport runs after the other ones (the other ones wrap this one); just to be on the safe side. + // If any of the other transports add any header, it must happen before the signing. + tr, err := aws.InitializeSignerTransport(*cfg.Auth.AWS, log, c.Transport) + if err != nil { + log.Errorw("failed to initialize aws config failed for signer", "error", err) + return nil, nil, err + } + c.Transport = tr } var trace *httplog.LoggingRoundTripper @@ -1012,7 +1023,7 @@ type socketDialer struct { } func (d socketDialer) Dial(_, _ string) (net.Conn, error) { - return net.Dial("unix", d.path) + return net.Dial("unix", d.path) //nolint:noctx // unix socket dial; no context propagation needed } func (d socketDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { @@ -1334,7 +1345,7 @@ func test(url *url.URL) error { return "80" }() - _, err := net.DialTimeout("tcp", net.JoinHostPort(url.Hostname(), port), time.Second) + _, err := net.DialTimeout("tcp", net.JoinHostPort(url.Hostname(), port), time.Second) //nolint:noctx // connectivity test; explicit timeout used instead of context if err != nil { return fmt.Errorf("url %q is unreachable: %w", url, err) } diff --git a/x-pack/filebeat/input/cel/input_test.go b/x-pack/filebeat/input/cel/input_test.go index afb7c1418b22..d8c84014adcf 100644 --- a/x-pack/filebeat/input/cel/input_test.go +++ b/x-pack/filebeat/input/cel/input_test.go @@ -19,6 +19,7 @@ import ( "path/filepath" "reflect" "runtime" + "strings" "sync" "testing" "time" @@ -615,9 +616,9 @@ var inputTests = []struct { ` - io.ReadAll(r.Body) + io.ReadAll(r.Body) //nolint:errcheck // No point checking errors in test server. r.Body.Close() - w.Write([]byte(text)) + w.Write([]byte(text)) //nolint:errcheck // No point checking errors in test server. }) server := httptest.NewServer(r) config["resource.url"] = server.URL @@ -749,7 +750,7 @@ var inputTests = []struct { msg = fmt.Sprintf(`{"error":"expected method was %#q"}`, http.MethodGet) } - w.Write([]byte(msg)) + w.Write([]byte(msg)) //nolint:errcheck // No point checking errors in test server. }, want: []map[string]interface{}{ { @@ -797,7 +798,7 @@ var inputTests = []struct { msg = fmt.Sprintf(`{"error":"expected method was %#q"}`, http.MethodGet) } - w.Write([]byte(msg)) + w.Write([]byte(msg)) //nolint:errcheck // No point checking errors in test server. }, want: []map[string]interface{}{ { @@ -830,7 +831,7 @@ var inputTests = []struct { }, handler: func(w http.ResponseWriter, r *http.Request) { enc := json.NewEncoder(w) - enc.Encode(map[string][]any{"events": {r.Header.Get("foo")}}) + enc.Encode(map[string][]any{"events": {r.Header.Get("foo")}}) //nolint:errcheck // No point checking errors in test server. }, want: []map[string]interface{}{ { @@ -851,7 +852,7 @@ var inputTests = []struct { `, }, handler: func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello")) + w.Write([]byte("hello")) //nolint:errcheck // No point checking errors in test server. }, want: []map[string]interface{}{ { @@ -872,7 +873,7 @@ var inputTests = []struct { `, }, handler: func(w http.ResponseWriter, r *http.Request) { - w.Write([]byte("hello")) + w.Write([]byte("hello")) //nolint:errcheck // No point checking errors in test server. }, want: []map[string]interface{}{ { @@ -911,7 +912,7 @@ var inputTests = []struct { msg = fmt.Sprintf(`{"error":"expected method was %#q"}`, http.MethodGet) } - w.Write([]byte(msg)) + w.Write([]byte(msg)) //nolint:errcheck // No point checking errors in test server. }, want: []map[string]interface{}{ { @@ -1890,6 +1891,44 @@ var inputTests = []struct { }, }, + { + name: "Auth AWS V4 Signer", + server: func(t *testing.T, h http.HandlerFunc, config map[string]interface{}) { + s := httptest.NewServer(h) + config["resource.url"] = s.URL + t.Cleanup(s.Close) + }, + config: map[string]interface{}{ + "interval": 1, + "auth.aws.access_key_id": "AKIAIOSFODNN7EXAMPLE", + "auth.aws.secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "auth.aws.default_region": "us-east-1", + "auth.aws.service_name": "guardduty", + "program": ` + bytes(get(state.url).Body).as(body, { + "events": [body.decode_json()] + }) + `, + }, + handler: awsAuthHandler("AKIAIOSFODNN7EXAMPLE", defaultHandler(http.MethodGet, "")), + want: []map[string]interface{}{ + { + "hello": []interface{}{ + map[string]interface{}{ + "world": "moon", + }, + map[string]interface{}{ + "space": []interface{}{ + map[string]interface{}{ + "cake": "pumpkin", + }, + }, + }, + }, + }, + }, + }, + // Multi-step requests. { name: "simple_multistep_GET_request", @@ -2270,7 +2309,7 @@ func TestInput(t *testing.T) { id := "test_id:" + test.name v2Ctx := v2.Context{ - Logger: logp.NewLogger("cel_test"), + Logger: logp.NewLogger("cel_test"), //nolint:forbidigo // test helper; no logp.Logger parameter available ID: id, IDWithoutName: id, Cancelation: ctx, @@ -2405,7 +2444,7 @@ func newChainTestServer(serve func(http.Handler) *httptest.Server) func(*testing func newV2Context() (v2.Context, func()) { ctx, cancel := context.WithCancel(context.Background()) return v2.Context{ - Logger: logp.NewLogger("httpjson_test"), + Logger: logp.NewLogger("httpjson_test"), //nolint:forbidigo // test helper; no logp.Logger parameter available ID: "test_id", Cancelation: ctx, }, cancel @@ -2482,7 +2521,6 @@ func retryHandler() http.HandlerFunc { } } -//nolint:errcheck // No point checking errors in test server. func tokenAuthHandler(want string, handle http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { auth := r.Header.Get("Authorization") @@ -2495,6 +2533,24 @@ func tokenAuthHandler(want string, handle http.HandlerFunc) http.HandlerFunc { } } +func awsAuthHandler(expectedTokenID string, handle http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/", expectedTokenID)) { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + amzDate := r.Header.Get("X-Amz-Date") + if amzDate == "" { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + handle(w, r) + } +} + //nolint:errcheck // No point checking errors in test server. func digestAuthHandler(user, pass, realm, nonce string, handle http.HandlerFunc) http.HandlerFunc { chal := &digest.Challenge{ diff --git a/x-pack/filebeat/input/httpjson/config_auth.go b/x-pack/filebeat/input/httpjson/config_auth.go index 4021d0c29e88..6dacf0773c2e 100644 --- a/x-pack/filebeat/input/httpjson/config_auth.go +++ b/x-pack/filebeat/input/httpjson/config_auth.go @@ -20,15 +20,27 @@ import ( "golang.org/x/oauth2/google" "github.com/elastic/beats/v7/libbeat/common" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" ) type authConfig struct { - Basic *basicAuthConfig `config:"basic"` - OAuth2 *oAuth2Config `config:"oauth2"` + Basic *basicAuthConfig `config:"basic"` + OAuth2 *oAuth2Config `config:"oauth2"` + AWS *aws.SignerInputConfig `config:"aws"` } func (c authConfig) Validate() error { - if c.Basic.isEnabled() && c.OAuth2.isEnabled() { + var n int + if c.Basic.isEnabled() { + n++ + } + if c.OAuth2.isEnabled() { + n++ + } + if c.AWS.IsEnabled() { + n++ + } + if n > 1 { return errors.New("only one kind of auth can be enabled") } return nil @@ -165,7 +177,7 @@ func (o *oAuth2Config) client(ctx context.Context, client *http.Client) (*http.C var creds *google.Credentials var err error if len(o.GoogleCredentialsJSON) != 0 { - creds, err = google.CredentialsFromJSON(ctx, o.GoogleCredentialsJSON, o.Scopes...) + creds, err = google.CredentialsFromJSON(ctx, o.GoogleCredentialsJSON, o.Scopes...) //nolint:staticcheck // deprecated but no suitable replacement available if err != nil { return nil, fmt.Errorf("oauth2 client: error loading credentials: %w", err) } diff --git a/x-pack/filebeat/input/httpjson/input.go b/x-pack/filebeat/input/httpjson/input.go index 7809421ad1e6..25dab16a660f 100644 --- a/x-pack/filebeat/input/httpjson/input.go +++ b/x-pack/filebeat/input/httpjson/input.go @@ -38,6 +38,7 @@ import ( "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httplog" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/httpmon" "github.com/elastic/beats/v7/x-pack/filebeat/input/internal/private" + "github.com/elastic/beats/v7/x-pack/libbeat/common/aws" "github.com/elastic/elastic-agent-libs/logp" "github.com/elastic/elastic-agent-libs/mapstr" "github.com/elastic/elastic-agent-libs/monitoring" @@ -105,7 +106,7 @@ type redact struct { func (r redact) MarshalLogObject(enc zapcore.ObjectEncoder) error { v, err := private.Redact(r.value, "", r.fields) if err != nil { - return fmt.Errorf("could not redact value: %v", err) + return fmt.Errorf("could not redact value: %w", err) } return v.MarshalLogObject(enc) } @@ -163,7 +164,7 @@ func test(url *url.URL) error { return "80" }() - _, err := net.DialTimeout("tcp", net.JoinHostPort(url.Hostname(), port), time.Second) + _, err := net.DialTimeout("tcp", net.JoinHostPort(url.Hostname(), port), time.Second) //nolint:noctx // test-only helper; timeout is explicit and context is not needed if err != nil { return fmt.Errorf("url %q is unreachable", url) } @@ -191,7 +192,7 @@ func run(ctx v2.Context, cfg config, pub inputcursor.Publisher, crsr *inputcurso return err } if !ok { - return fmt.Errorf("request tracer path %q must be within %q path", path, paths.Resolve(paths.Logs, inputName)) + return fmt.Errorf("request tracer path %q must be within %q path", path, paths.Resolve(paths.Logs, inputName)) //nolint:forbidigo // no per-beat path instance available here } cfg.Request.Tracer.Filename = resolved @@ -207,6 +208,7 @@ func run(ctx v2.Context, cfg config, pub inputcursor.Publisher, crsr *inputcurso } metrics := newInputMetrics(reg, ctx.Logger) + client, err := newHTTPClient(stdCtx, cfg.Auth, cfg.Request, ctx, log, reg, nil) if err != nil { ctx.UpdateStatus(status.Failed, "failed to create HTTP client: "+err.Error()) @@ -304,7 +306,22 @@ func newHTTPClient(ctx context.Context, authCfg *authConfig, requestCfg *request client *http.Client err error ) - if authCfg.OAuth2.isEnabled() { + switch { + case authCfg.AWS.IsEnabled(): + client, err = newNetHTTPClient(ctx, requestCfg, log, reg) + if err != nil { + log.Errorw("creation of initial http client failed", "error", err) + return nil, err + } + + log.Debugw("creating signer", "region", authCfg.AWS.DefaultRegion, "service", authCfg.AWS.ServiceName) + tr, err := aws.InitializeSignerTransport(*authCfg.AWS, log, client.Transport) + if err != nil { + log.Errorw("failed to initialize aws config failed for signer", "error", err) + return nil, err + } + client.Transport = tr + case authCfg.OAuth2.isEnabled(): client = authCfg.OAuth2.prepared if client == nil { client, err = newNetHTTPClient(ctx, requestCfg, log, reg) @@ -317,7 +334,7 @@ func newHTTPClient(ctx context.Context, authCfg *authConfig, requestCfg *request } authCfg.OAuth2.prepared = client } - } else { + default: client, err = newNetHTTPClient(ctx, requestCfg, log, reg) if err != nil { return nil, err @@ -444,7 +461,7 @@ type socketDialer struct { } func (d socketDialer) Dial(_, _ string) (net.Conn, error) { - return net.Dial("unix", d.path) + return net.Dial("unix", d.path) //nolint:noctx // unix socket dial; no context propagation needed } func (d socketDialer) DialContext(ctx context.Context, _, _ string) (net.Conn, error) { diff --git a/x-pack/filebeat/input/httpjson/input_test.go b/x-pack/filebeat/input/httpjson/input_test.go index 7a18b704d36a..08683ad3aa63 100644 --- a/x-pack/filebeat/input/httpjson/input_test.go +++ b/x-pack/filebeat/input/httpjson/input_test.go @@ -12,6 +12,7 @@ import ( "net/http/httptest" "os" "path/filepath" + "strings" "testing" "time" @@ -653,6 +654,24 @@ var testCases = []struct { handler: oauth2Handler, expected: []string{`{"hello": "world"}`}, }, + { + name: "aws auth", + setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { + server := httptest.NewServer(h) + config["request.url"] = server.URL + t.Cleanup(server.Close) + }, + baseConfig: map[string]interface{}{ + "interval": 1, + "request.method": http.MethodGet, + "auth.aws.access_key_id": "AKIAIOSFODNN7EXAMPLE", + "auth.aws.secret_access_key": "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", + "auth.aws.default_region": "us-east-1", + "auth.aws.service_name": "guardduty", + }, + handler: awsAuthHandler("AKIAIOSFODNN7EXAMPLE", defaultHandler(http.MethodGet, "", "")), + expected: []string{`{"hello":[{"world":"moon"},{"space":[{"cake":"pumpkin"}]}]}`}, + }, { name: "request_transforms_can_access_state_from_previous_transforms", setupServer: func(t testing.TB, h http.HandlerFunc, config map[string]interface{}) { @@ -1528,9 +1547,9 @@ var testCases = []struct { ` - io.ReadAll(r.Body) + _, _ = io.ReadAll(r.Body) r.Body.Close() - w.Write([]byte(text)) + _, _ = w.Write([]byte(text)) }) server := httptest.NewServer(r) config["request.url"] = server.URL @@ -1603,7 +1622,7 @@ var testCases = []struct { } func TestInput(t *testing.T) { - logp.TestingSetup() + logp.TestingSetup() //nolint:staticcheck // deprecated but logptest.NewTestingLogger is not yet used here for _, test := range testCases { t.Run(test.name, func(t *testing.T) { @@ -1683,7 +1702,7 @@ func TestInput(t *testing.T) { case got := <-chanClient.Channel: val, err := got.Fields.GetValue("message") assert.NoError(t, err) - assert.JSONEq(t, test.expected[receivedCount], val.(string)) + assert.JSONEq(t, test.expected[receivedCount], val.(string)) //nolint:errcheck // type assertion on known string value receivedCount += 1 if receivedCount == len(test.expected) { cancel() @@ -1849,7 +1868,7 @@ func newChainPaginationTestServer( func newV2Context(id string) (v2.Context, func()) { ctx, cancel := context.WithCancel(context.Background()) return v2.Context{ - Logger: logp.NewLogger("httpjson_test"), + Logger: logp.NewLogger("httpjson_test"), //nolint:forbidigo // test helper; no logp.Logger parameter available ID: id, IDWithoutName: id, Cancelation: ctx, @@ -1889,6 +1908,24 @@ func defaultHandler(expectedMethod, expectedBody, msg string) http.HandlerFunc { } } +func awsAuthHandler(expectedTokenID string, handle http.HandlerFunc) http.HandlerFunc { + return func(w http.ResponseWriter, r *http.Request) { + authHeader := r.Header.Get("Authorization") + if !strings.HasPrefix(authHeader, fmt.Sprintf("AWS4-HMAC-SHA256 Credential=%s/", expectedTokenID)) { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + amzDate := r.Header.Get("X-Amz-Date") + if amzDate == "" { + http.Error(w, `{"error":"not authorized"}`, http.StatusBadRequest) + return + } + + handle(w, r) + } +} + func rateLimitHandler() http.HandlerFunc { var isRetry bool return func(w http.ResponseWriter, r *http.Request) { diff --git a/x-pack/libbeat/common/aws/cloud_connectors.go b/x-pack/libbeat/common/aws/cloud_connectors.go new file mode 100644 index 000000000000..841183dea70e --- /dev/null +++ b/x-pack/libbeat/common/aws/cloud_connectors.go @@ -0,0 +1,111 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package aws + +import ( + "errors" + "fmt" + "os" + "time" + + awssdk "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials/stscreds" + "github.com/aws/aws-sdk-go-v2/service/sts" + + "github.com/elastic/elastic-agent-libs/logp" +) + +// These env vars are provided by agentless controller when the cloud connectors flow is enabled. +const ( + CloudConnectorsGlobalRoleEnvVar = "CLOUD_CONNECTORS_GLOBAL_ROLE" + CloudConnectorsJWTPathEnvVar = "CLOUD_CONNECTORS_ID_TOKEN_FILE" + CloudConnectorsCloudResourceIDEnvVar = "CLOUD_RESOURCE_ID" +) + +// CloudConnectorsConfig is the config for the cloud connectors flow +type CloudConnectorsConfig struct { + ElasticGlobalRoleARN string + IDTokenPath string + CloudResourceID string +} + +func parseCloudConnectorsConfigFromEnv() (CloudConnectorsConfig, error) { + cc := CloudConnectorsConfig{ + ElasticGlobalRoleARN: os.Getenv(CloudConnectorsGlobalRoleEnvVar), + IDTokenPath: os.Getenv(CloudConnectorsJWTPathEnvVar), + CloudResourceID: os.Getenv(CloudConnectorsCloudResourceIDEnvVar), + } + + var errs []error + + if cc.ElasticGlobalRoleARN == "" { + errs = append(errs, errors.New("elastic global role arn is not configured")) + } + if cc.IDTokenPath == "" { + errs = append(errs, errors.New("id token path is not configured")) + } + if cc.CloudResourceID == "" { + errs = append(errs, errors.New("cloud resource id is not configured")) + } + + if len(errs) > 0 { + return CloudConnectorsConfig{}, fmt.Errorf("cloud connectors config is invalid: %w", errors.Join(errs...)) + } + + return cc, nil +} + +const defaultIntermediateDuration = 20 * time.Minute + +func addCloudConnectorsCredentials(config ConfigAWS, cloudConnectorsConfig CloudConnectorsConfig, awsConfig *awssdk.Config, logger *logp.Logger) { + logger = logger.Named("addCloudConnectorsCredentials") + logger.Debug("Switching credentials provider to Cloud Connectors") + + addCredentialsChain( + awsConfig, + + // Step 1: Assume the Elastic Global Role with web identity using the ID token provided by the agentless OIDC issuer. + func(c awssdk.Config) awssdk.CredentialsProvider { + provider := stscreds.NewWebIdentityRoleProvider( + sts.NewFromConfig(c), // client uses credentials from previous config. + cloudConnectorsConfig.ElasticGlobalRoleARN, + stscreds.IdentityTokenFile(cloudConnectorsConfig.IDTokenPath), + func(opt *stscreds.WebIdentityRoleOptions) { + opt.Duration = defaultIntermediateDuration + }, + ) + return awssdk.NewCredentialsCache(provider) + }, + + // Step 2: Assume the remote role (the user's configured role), using the previously assumed role in the chain. + func(c awssdk.Config) awssdk.CredentialsProvider { + assumeRoleProvider := stscreds.NewAssumeRoleProvider( + sts.NewFromConfig(c), // client uses credentials from previous config. + config.RoleArn, + func(aro *stscreds.AssumeRoleOptions) { + aro.Duration = config.AssumeRoleDuration + if config.ExternalID != "" { + aro.ExternalID = awssdk.String(cloudConnectorsExternalID(cloudConnectorsConfig.CloudResourceID, config.ExternalID)) + } + }, + ) + return awssdk.NewCredentialsCache(assumeRoleProvider, func(options *awssdk.CredentialsCacheOptions) { + if config.AssumeRoleExpiryWindow > 0 { + options.ExpiryWindow = config.AssumeRoleExpiryWindow + } + }) + }, + ) +} + +func cloudConnectorsExternalID(resourceID, externalIDPart string) string { + return fmt.Sprintf("%s-%s", resourceID, externalIDPart) +} + +func addCredentialsChain(awsConfig *awssdk.Config, chain ...func(awssdk.Config) awssdk.CredentialsProvider) { + for _, fn := range chain { + awsConfig.Credentials = fn(*awsConfig) + } +} diff --git a/x-pack/libbeat/common/aws/cloud_connectors_test.go b/x-pack/libbeat/common/aws/cloud_connectors_test.go new file mode 100644 index 000000000000..11f91f5961cf --- /dev/null +++ b/x-pack/libbeat/common/aws/cloud_connectors_test.go @@ -0,0 +1,182 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package aws + +import ( + "context" + "fmt" + "io" + "net/url" + "os" + "path" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/aws/aws-sdk-go-v2/service/sts/types" + "github.com/aws/smithy-go/middleware" + smithyhttp "github.com/aws/smithy-go/transport/http" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/logp/logptest" +) + +func TestAddCloudConnectorsCredentials(t *testing.T) { + config := ConfigAWS{ + RoleArn: "arn:aws:iam::123456789012:role/customer-role", + ExternalID: "external-id-456", + AssumeRoleDuration: 2 * time.Hour, + AssumeRoleExpiryWindow: 10 * time.Minute, + } + cloudConnectorsConfig := CloudConnectorsConfig{ + ElasticGlobalRoleARN: "arn:aws:iam::999999999999:role/elastic-global-role", + CloudResourceID: "abcd1234", + } + tokenFileContent := "abc123" + + tmpDir := t.TempDir() + pth := path.Join(tmpDir, "id_token") + _ = os.WriteFile(path.Join(tmpDir, "id_token"), []byte(tokenFileContent), 0o644) + cloudConnectorsConfig.IDTokenPath = pth + + // Create a base AWS config + awsConfig := &aws.Config{ + Region: "us-east-1", + BaseEndpoint: aws.String("https://aws.mock"), + } + + // Create a test logger + logger := logptest.NewTestingLogger(t, "") + + // mock responses + receivedCalls := 0 + awsConfig.APIOptions = append(awsConfig.APIOptions, func(stack *middleware.Stack) error { + return stack.Finalize.Add( + middleware.FinalizeMiddlewareFunc( + "mock", + func(ctx context.Context, in middleware.FinalizeInput, next middleware.FinalizeHandler) (middleware.FinalizeOutput, middleware.Metadata, error) { + req, is := in.Request.(*smithyhttp.Request) + require.Truef(t, is, "request expected to be of type *smithyhttp.Request, got: %T", in.Request) + receivedCalls++ + bd, err := io.ReadAll(req.GetStream()) + assert.NoError(t, req.RewindStream()) + assert.NoError(t, err) + body := string(bd) + + switch receivedCalls { + + // Expect the first request to be AssumeRoleWithWebIdentity + case 1: + q, err := url.ParseQuery(body) + assert.NoError(t, err) + assert.Equal(t, "AssumeRoleWithWebIdentity", q.Get("Action")) + assert.Equal(t, "1200", q.Get("DurationSeconds")) + assert.Equal(t, cloudConnectorsConfig.ElasticGlobalRoleARN, q.Get("RoleArn")) + assert.Equal(t, tokenFileContent, q.Get("WebIdentityToken")) + return middleware.FinalizeOutput{ + Result: &sts.AssumeRoleWithWebIdentityOutput{ + Credentials: &types.Credentials{ + AccessKeyId: aws.String("AKIAFAKEEXAMPLE00001"), + SecretAccessKey: aws.String("FAKEwJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY1"), + SessionToken: aws.String("FwoGZXIvYXdzEFAaDFAKESESSIONTOKENEXAMPLE1"), + Expiration: aws.Time(time.Now().Add(defaultIntermediateDuration)), + }, + }, + }, middleware.Metadata{}, nil + + // Expect the second request to be AssumeRole + case 2: + q, err := url.ParseQuery(body) + assert.NoError(t, err) + assert.Equal(t, "AssumeRole", q.Get("Action")) + assert.Equal(t, "7200", q.Get("DurationSeconds")) + assert.Equal(t, cloudConnectorsExternalID(cloudConnectorsConfig.CloudResourceID, config.ExternalID), q.Get("ExternalId")) + assert.Equal(t, config.RoleArn, q.Get("RoleArn")) + return middleware.FinalizeOutput{ + Result: &sts.AssumeRoleOutput{ + Credentials: &types.Credentials{ + AccessKeyId: aws.String("AKIAFAKEEXAMPLE00002"), + SecretAccessKey: aws.String("FAKEwJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY2"), + SessionToken: aws.String("FwoGZXIvYXdzEFAaDFAKESESSIONTOKENEXAMPLE2"), + Expiration: aws.Time(time.Now().Add(defaultIntermediateDuration)), + }, + }, + }, middleware.Metadata{}, nil + + default: + t.Fatal("unexpected aws sdk call") + return middleware.FinalizeOutput{}, middleware.Metadata{}, fmt.Errorf("unexpected operation") + } + }, + ), + middleware.After, + ) + }) + + // Call the function under test + addCloudConnectorsCredentials( + config, + cloudConnectorsConfig, + awsConfig, + logger, + ) + + // Verify that credentials provider was set + require.NotNil(t, awsConfig.Credentials, "credentials provider should be set") + + crd, err := awsConfig.Credentials.Retrieve(t.Context()) + require.NoError(t, err) + require.NotNil(t, crd) + require.Equal(t, 2, receivedCalls) +} + +func TestCloudConnectorsExternalID(t *testing.T) { + assert.Equal(t, "resource1-ext-id", cloudConnectorsExternalID("resource1", "ext-id")) + assert.Equal(t, "abc123-external-id-456", cloudConnectorsExternalID("abc123", "external-id-456")) + assert.Equal(t, "single-", cloudConnectorsExternalID("single", "")) // format is always "resourceID-externalIDPart" +} + +func TestParseCloudConnectorsConfigFromEnv(t *testing.T) { + t.Run("happy_path", func(t *testing.T) { + t.Setenv(CloudConnectorsGlobalRoleEnvVar, "arn:aws:iam::999999999999:role/elastic-global-role") + t.Setenv(CloudConnectorsJWTPathEnvVar, "/path/token") + t.Setenv(CloudConnectorsCloudResourceIDEnvVar, "abc123") + + got, err := parseCloudConnectorsConfigFromEnv() + + require.NoError(t, err) + + assert.Equal( + t, + CloudConnectorsConfig{ + ElasticGlobalRoleARN: "arn:aws:iam::999999999999:role/elastic-global-role", + IDTokenPath: "/path/token", + CloudResourceID: "abc123", + }, + got, + ) + }) + + t.Run("missing config single", func(t *testing.T) { + t.Setenv(CloudConnectorsGlobalRoleEnvVar, "arn:aws:iam::999999999999:role/elastic-global-role") + t.Setenv(CloudConnectorsJWTPathEnvVar, "/path/token") + + got, err := parseCloudConnectorsConfigFromEnv() + + require.ErrorContains(t, err, "cloud resource id") + assert.Equal(t, CloudConnectorsConfig{}, got) + }) + + t.Run("missing config all", func(t *testing.T) { + got, err := parseCloudConnectorsConfigFromEnv() + + require.ErrorContains(t, err, "elastic global role") + require.ErrorContains(t, err, "id token") + require.ErrorContains(t, err, "cloud resource id") + assert.Equal(t, CloudConnectorsConfig{}, got) + }) +} diff --git a/x-pack/libbeat/common/aws/credentials.go b/x-pack/libbeat/common/aws/credentials.go index dc45a8c19769..e318d4bda1ee 100644 --- a/x-pack/libbeat/common/aws/credentials.go +++ b/x-pack/libbeat/common/aws/credentials.go @@ -52,6 +52,10 @@ type ConfigAWS struct { // AssumeRoleExpiryWindow will allow the credentials to trigger refreshing prior to the credentials // actually expiring. If expiry_window is less than or equal to zero, the setting is ignored. AssumeRoleExpiryWindow time.Duration `config:"assume_role.expiry_window"` + + // UseCloudConnectors indicates whether the cloud connectors flow is used. + // If this is true, the InitializeAWSConfig should initialize the AWS cloud connector role chaining flow. + UseCloudConnectors bool `config:"use_cloud_connectors"` } // InitializeAWSConfig function creates the awssdk.Config object from the provided config @@ -66,10 +70,19 @@ func InitializeAWSConfig(beatsConfig ConfigAWS, logger *logp.Logger) (awssdk.Con } // Assume IAM role if iam_role config parameter is given - if beatsConfig.RoleArn != "" { + if beatsConfig.RoleArn != "" && !beatsConfig.UseCloudConnectors { addAssumeRoleProviderToAwsConfig(beatsConfig, &awsConfig, logger) } + // If cloud connectors method is selected from config, initialize the role chaining. + if beatsConfig.UseCloudConnectors { + cloudConnectorsConfig, err := parseCloudConnectorsConfigFromEnv() + if err != nil { + return awsConfig, err + } + addCloudConnectorsCredentials(beatsConfig, cloudConnectorsConfig, &awsConfig, logger) + } + var proxy func(*http.Request) (*url.URL, error) if beatsConfig.ProxyUrl != "" { proxyUrl, err := httpcommon.NewProxyURIFromString(beatsConfig.ProxyUrl) diff --git a/x-pack/libbeat/common/aws/credentials_test.go b/x-pack/libbeat/common/aws/credentials_test.go index ce09aeac98b5..71cb3e91b57c 100644 --- a/x-pack/libbeat/common/aws/credentials_test.go +++ b/x-pack/libbeat/common/aws/credentials_test.go @@ -9,12 +9,36 @@ import ( "net/http" "testing" + "github.com/aws/aws-sdk-go-v2/aws" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "github.com/elastic/elastic-agent-libs/logp/logptest" "github.com/elastic/elastic-agent-libs/transport/tlscommon" ) +func TestInitializeAWSConfigCloudConnectors(t *testing.T) { + t.Setenv(CloudConnectorsGlobalRoleEnvVar, "arn:aws:iam::999999999999:role/elastic-global-role") + t.Setenv(CloudConnectorsJWTPathEnvVar, "/path/token") + t.Setenv(CloudConnectorsCloudResourceIDEnvVar, "abc123") + + inputConfig := ConfigAWS{ + RoleArn: "arn:aws:iam::123456789012:role/customer-role", + ExternalID: "external-id-456", + UseCloudConnectors: true, + } + + awsConfig, err := InitializeAWSConfig(inputConfig, logptest.NewTestingLogger(t, "")) + assert.NoError(t, err) + + // we cannot append to APIOptions at this point (and mock the chain responses) + // because a copy of config has already been passed to each sts client. + // So lets just check that .Credentials is CredentialsCache (so cloud connectors init was run). + c, isCredCache := awsConfig.Credentials.(*aws.CredentialsCache) + require.True(t, isCredCache) + require.NotNil(t, c) +} + func TestInitializeAWSConfig(t *testing.T) { inputConfig := ConfigAWS{ AccessKeyID: "123", diff --git a/x-pack/libbeat/common/aws/signer.go b/x-pack/libbeat/common/aws/signer.go new file mode 100644 index 000000000000..74b17f7f27e7 --- /dev/null +++ b/x-pack/libbeat/common/aws/signer.go @@ -0,0 +1,207 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package aws + +import ( + "bytes" + "crypto/sha256" + "encoding/hex" + "errors" + "fmt" + "io" + "net/http" + "strings" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + v4 "github.com/aws/aws-sdk-go-v2/aws/signer/v4" + awslogging "github.com/aws/smithy-go/logging" + + "github.com/elastic/elastic-agent-libs/logp" +) + +// SignerInputConfig is the top-level configuration for the input aws auth method, +// used both in CEL and HTTP JSON inputs. It wraps the [ConfigAWS] inlined. +type SignerInputConfig struct { + // Enabled indicates whether this auth method is used. + // [SignerInputConfig.IsEnabled] implements the logic of the check. + Enabled *bool `config:"enabled"` + + // ServiceName can be used to optionally set the AWS service name that AWS V4 signer will sign for. + // If not value is set here, the signer will try to infer the service from the url of the request. + ServiceName string `config:"service_name"` + + // ConfigAWS is the inline wrapping of the rest of the [ConfigAWS] that can be use in the input config. + ConfigAWS `config:",inline"` +} + +// IsEnabled returns true if the `enable` field is set to true in the yaml or if it is nil. +func (c *SignerInputConfig) IsEnabled() bool { + return c != nil && (c.Enabled == nil || *c.Enabled) +} + +// SignerTransport implements [http.RoundTripper] interface +// and signs requests with aws v4 signer before send them to the next roundtripper. +// If the `serviceName` and `region` are not set, the signer will try to infer them from each request's URL. +type SignerTransport struct { + next http.RoundTripper + credentials aws.CredentialsProvider + signer *v4.Signer + logger *logp.Logger + serviceName string + region string + now func() time.Time // we don't use [time.Now] directly, so we can mock time in tests. +} + +// InitializeSignerTransport initializes first the AWS config using the [InitializeAWSConfig] and the [ConfigAWS] and then, +// initializes the RoundTripper using the AWS credentials from the previously initialized AWS config. +func InitializeSignerTransport(cfg SignerInputConfig, logger *logp.Logger, nextTransport http.RoundTripper) (*SignerTransport, error) { + awsConfig, err := InitializeAWSConfig(cfg.ConfigAWS, logger) + if err != nil { + return nil, err + } + + return initializeSignerTransport(logger, cfg.ServiceName, cfg.DefaultRegion, awsConfig.Credentials, nextTransport), nil +} + +func initializeSignerTransport(logger *logp.Logger, defaultServiceName string, defaultRegion string, credentials aws.CredentialsProvider, nextTransport http.RoundTripper) *SignerTransport { + return &SignerTransport{ + next: nextTransport, + credentials: credentials, + signer: v4.NewSigner(func(signer *v4.SignerOptions) { + signer.Logger = awslogging.LoggerFunc(func(classification awslogging.Classification, format string, v ...any) { + switch classification { + case awslogging.Debug: + logger.Debugf(format, v...) + case awslogging.Warn: + logger.Warnf(format, v...) + } + }) + }), + logger: logger, + serviceName: defaultServiceName, + region: defaultRegion, + now: time.Now, + } +} + +func (st *SignerTransport) RoundTrip(req *http.Request) (*http.Response, error) { + // resolve service name and region (if they are not configured) + serviceName, region, err := st.getServiceAndRegion(req) + if err != nil { + return nil, fmt.Errorf("error while getting service name and region: %w", err) + } + + // retrieve credentials + creds, err := st.credentials.Retrieve(req.Context()) + if err != nil { + return nil, fmt.Errorf("error while retrieving credentials: %w", err) + } + + // body hash + payloadHash, err := st.bodySHA256Hash(req) + if err != nil { + return nil, fmt.Errorf("error while calculating body hash: %w", err) + } + + // sign the request + err = st.signer.SignHTTP(req.Context(), creds, req, payloadHash, serviceName, region, st.now()) + if err != nil { + return nil, fmt.Errorf("error while signing the request: %w", err) + } + + // next transport + return st.next.RoundTrip(req) +} + +// bodySHA256Hash returns the sha256 hash of the request's body by reading a copy of the body. +// The request's Body remains readable and unmodified after this function returns. +func (st *SignerTransport) bodySHA256Hash(req *http.Request) (string, error) { + if req.Body == nil || req.Body == http.NoBody { + return hex.EncodeToString(sha256.New().Sum(nil)), nil + } + + // this is a copy of the original body + body, err := st.getBody(req) + if err != nil { + return "", err + } + + hash := sha256.New() + + if _, err := io.Copy(hash, body); err != nil { + return "", err + } + + return hex.EncodeToString(hash.Sum(nil)), body.Close() +} + +// getBody returns a copy of the request's body as a [io.ReadCloser]. +// The request's Body remains readable and unmodified after this function returns. +func (st *SignerTransport) getBody(req *http.Request) (io.ReadCloser, error) { + if req.GetBody != nil { + // [http.Request] GetBody dictates that a new copy of the body must be returned. + return req.GetBody() + } + + if req.Body == http.NoBody || req.Body == nil { + return req.Body, nil + } + + // If the GetBody does not exist we need to manually copy the body. + // In Beats use-case its not possible for this to happen, + // since, both in cel and httpjson, the request is initialized + // with *bytes.Buffer, *bytes.Reader or *strings.Reader as body, which gets GetBody initialized. + // httpjson: (x-pack/filebeat/input/httpjson/request.go newHTTPRequest) + // cel: mito repo (lib/http.go) + // We cover the edge case here by reading and copying the body manually. + bodyBytes, err := io.ReadAll(req.Body) + if err != nil { + return nil, fmt.Errorf("error while reading request body: %w", err) + } + if err := req.Body.Close(); err != nil { + st.logger.Warnf("error while closing copied body %s", err.Error()) + } + + // reset body to the request + req.Body = io.NopCloser(bytes.NewBuffer(bodyBytes)) + + return io.NopCloser(bytes.NewBuffer(bodyBytes)), nil +} + +// getServiceAndRegion returns the service name and the region for the upcoming request. +// If service name and region are configured with default values, those take precedence. +// Otherwise it will try to parse the values from [http.Request] Host value. +func (st *SignerTransport) getServiceAndRegion(req *http.Request) (serviceName, region string, err error) { + serviceName = st.serviceName + region = st.region + + if serviceName == "" || region == "" { + s, r, err := parseServiceAndRegionFromHost(req.Host) + if err != nil { + return "", "", err + } + if serviceName == "" { + serviceName = s + } + if region == "" { + region = r + } + } + + return serviceName, region, nil +} + +func parseServiceAndRegionFromHost(host string) (service, region string, err error) { + parts := strings.SplitN(host, ".", 4) + + if len(parts) < 4 { + return "", "", errMalformedHost + } + + return parts[0], parts[1], nil +} + +var errMalformedHost = errors.New("malformed host string") diff --git a/x-pack/libbeat/common/aws/signer_test.go b/x-pack/libbeat/common/aws/signer_test.go new file mode 100644 index 000000000000..69fd8886b7eb --- /dev/null +++ b/x-pack/libbeat/common/aws/signer_test.go @@ -0,0 +1,319 @@ +// Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one +// or more contributor license agreements. Licensed under the Elastic License; +// you may not use this file except in compliance with the Elastic License. + +package aws + +import ( + "bytes" + "io" + "net/http" + "testing" + "time" + + "github.com/aws/aws-sdk-go-v2/aws" + "github.com/aws/aws-sdk-go-v2/credentials" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/mock" + "github.com/stretchr/testify/require" + + "github.com/elastic/elastic-agent-libs/logp/logptest" +) + +func mockNow(v time.Time) func() time.Time { return func() time.Time { return v } } + +type mockRoundTripper struct { + mock.Mock + req *http.Request +} + +func (m *mockRoundTripper) RoundTrip(req *http.Request) (*http.Response, error) { + args := m.Called(req) + m.req = req // store the request for later assertions. + return args.Get(0).(*http.Response), args.Error(1) //nolint:errcheck // not needed here. +} + +func TestSignerTransportRoundTrip(t *testing.T) { + now := mockNow(time.Date(2025, time.October, 11, 16, 0, 0, 0, time.UTC)) + + // fake credentials received from this: https://docs.aws.amazon.com/STS/latest/APIReference/API_GetAccessKeyInfo.html + fakeStaticCreds := credentials.NewStaticCredentialsProvider("AKIAIOSFODNN7EXAMPLE", "wJalrXUtnFEMI/K7MDENG/bPxRfiCYEXAMPLEKEY", "session") + + tests := []struct { + name string + defaultServiceName string + defaultRegion string + url string + requestBody io.Reader + requestHeaders map[string]string + credentials aws.CredentialsProvider + now func() time.Time + initMockRoundTripper func(*mockRoundTripper) + expectError bool + expectedRequestHeaders map[string]string + expectedRequestBody []byte + }{ + { + name: "no body", + defaultServiceName: "", + defaultRegion: "", + url: "https://guardduty.us-east-1.amazonaws.com/detector/abc123/findings", + requestBody: http.NoBody, + requestHeaders: map[string]string{}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=a73ff41e90b3e54c8855dc53cb352c244f4cf39122838e4ded22eef0fde01095", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + }, + expectedRequestBody: []byte{}, + }, + { + name: "no body overwrite service name and region", + defaultServiceName: "guardduty", + defaultRegion: "us-east-1", + url: "https://guardduty2.us-east-2.amazonaws.com/detector/abc123/findings", + requestBody: http.NoBody, + requestHeaders: map[string]string{}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=host;x-amz-date;x-amz-security-token, Signature=2bc3ea894efa9703ec95cac0bdcd6a1067a64636058b66e88640af2dc06ff2dd", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + }, + expectedRequestBody: []byte{}, + }, + { + name: "with body", + defaultServiceName: "", + defaultRegion: "", + url: "https://guardduty.us-east-1.amazonaws.com/detector/abc123/findings", + requestBody: bytes.NewBuffer([]byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`)), + requestHeaders: map[string]string{}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=content-length;host;x-amz-date;x-amz-security-token, Signature=1cba6843418733071843e982a5e399eebfa3caeef3bae336ab4477abf42a9fb7", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + }, + expectedRequestBody: []byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`), + }, + { + name: "with body and headers", + defaultServiceName: "", + defaultRegion: "", + url: "https://guardduty.us-east-1.amazonaws.com/detector/abc123/findings", + requestBody: bytes.NewBuffer([]byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`)), + requestHeaders: map[string]string{"X-Extra-Header": "abc123"}, + credentials: fakeStaticCreds, + now: now, + initMockRoundTripper: func(mrt *mockRoundTripper) { + mrt.On("RoundTrip", mock.Anything).Return(&http.Response{}, nil).Once() + }, + expectError: false, + expectedRequestHeaders: map[string]string{ + "Authorization": "AWS4-HMAC-SHA256 Credential=AKIAIOSFODNN7EXAMPLE/20251011/us-east-1/guardduty/aws4_request, SignedHeaders=content-length;host;x-amz-date;x-amz-security-token;x-extra-header, Signature=a9ae9766395c5749fca156baf9c65ef78d4d3053866299db48838c3546aaeb25", + "X-Amz-Date": "20251011T160000Z", + "X-Amz-Security-Token": "session", + "X-Extra-Header": "abc123", + }, + expectedRequestBody: []byte(`{"findingIds": [ "abc" ], "sortCriteria": {"attributeName":"updatedAt","orderBy":"ASC"}}`), + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := logptest.NewTestingLogger(t, "") + + m := mockRoundTripper{} + if tc.initMockRoundTripper != nil { + tc.initMockRoundTripper(&m) + } + + st := initializeSignerTransport(logger, tc.defaultServiceName, tc.defaultRegion, tc.credentials, &m) + st.now = tc.now + + assert.Equal(t, tc.defaultServiceName, st.serviceName) + assert.Equal(t, tc.defaultRegion, st.region) + + req, err := http.NewRequestWithContext(t.Context(), http.MethodGet, tc.url, tc.requestBody) + require.NoError(t, err) + for k, v := range tc.requestHeaders { + req.Header.Set(k, v) + } + + _, err = st.RoundTrip(req) //nolint:bodyclose // we don't actually have response body here + errAssert := assert.NoError + if tc.expectError { + errAssert = assert.Error + } + errAssert(t, err) + + gotHeaders := map[string]string{} + for k := range m.req.Header { + gotHeaders[k] = m.req.Header.Get(k) + } + assert.Equal(t, tc.expectedRequestHeaders, gotHeaders) + + // ensure that request's body is readable (and not consumed) after the hash operation. + b, err := io.ReadAll(req.Body) + require.NoError(t, err) + assert.Equal(t, tc.expectedRequestBody, b) + }) + } +} + +func TestBodySHA256Hash(t *testing.T) { + tests := []struct { + name string + body io.Reader + expectedHash string + expectError bool + }{ + { + name: "no body", + body: http.NoBody, + expectedHash: "e3b0c44298fc1c149afbf4c8996fb92427ae41e4649b934ca495991b7852b855", + expectError: false, + }, + { + name: "with body that initializes GetBody", + body: bytes.NewReader([]byte(`"abc"`)), + expectedHash: "6cc43f858fbb763301637b5af970e2a46b46f461f27e5a0f41e009c59b827b25", + expectError: false, + }, + { + name: "with body without initialized GetBody", + body: io.NopCloser(bytes.NewReader([]byte(`"abc"`))), + expectedHash: "6cc43f858fbb763301637b5af970e2a46b46f461f27e5a0f41e009c59b827b25", + expectError: false, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := logptest.NewTestingLogger(t, "") + st := SignerTransport{ + next: nil, + credentials: nil, + signer: nil, + logger: logger, + serviceName: "", + region: "", + now: time.Now, + } + + req, err := http.NewRequestWithContext(t.Context(), "GET", "sample.amazonaws.com", tc.body) + require.NoError(t, err) + + gotHash, gotErr := st.bodySHA256Hash(req) + + assert.Equal(t, tc.expectedHash, gotHash, "hash of body is different than expected") + errAssert := assert.NoError + if tc.expectError { + errAssert = assert.Error + } + errAssert(t, gotErr) + }) + } +} + +func TestGetServiceAndRegion(t *testing.T) { + tests := []struct { + name string + configuredServiceName string + configuredRegion string + requestHost string + expectedServiceName string + expectedRegion string + expectError bool + }{ + { + name: "extract from host", + configuredServiceName: "", + configuredRegion: "", + requestHost: "guardduty.us-east-1.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "configured values take precedence", + configuredServiceName: "guardduty", + configuredRegion: "us-east-1", + requestHost: "abc.us-east-2.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "service name configured region from url", + configuredServiceName: "guardduty", + configuredRegion: "", + requestHost: "abc.us-east-2.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-2", + expectError: false, + }, + { + name: "service name from url region configured", + configuredServiceName: "", + configuredRegion: "us-east-1", + requestHost: "guardduty.us-east-2.amazonaws.com", + expectedServiceName: "guardduty", + expectedRegion: "us-east-1", + expectError: false, + }, + { + name: "malformed host", + configuredServiceName: "", + configuredRegion: "", + requestHost: "amazonaws.com", + expectedServiceName: "", + expectedRegion: "", + expectError: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + logger := logptest.NewTestingLogger(t, "") + st := SignerTransport{ + next: nil, + credentials: nil, + signer: nil, + logger: logger, + serviceName: tc.configuredServiceName, + region: tc.configuredRegion, + now: time.Now, + } + + req := &http.Request{Host: tc.requestHost} + + gotServiceName, gotRegion, gotErr := st.getServiceAndRegion(req) + + assert.Equal(t, tc.expectedServiceName, gotServiceName, "service name is different than expected") + assert.Equal(t, tc.expectedRegion, gotRegion, "service name is different than expected") + errAssert := assert.NoError + if tc.expectError { + errAssert = assert.Error + } + errAssert(t, gotErr) + }) + } +}