Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Move credentials endpoints to ecs-agent module #3698

Merged
merged 4 commits into from
May 17, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 5 additions & 4 deletions agent/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/aws/amazon-ecs-agent/agent/dockerclient"
"github.com/aws/amazon-ecs-agent/agent/ec2"
"github.com/aws/amazon-ecs-agent/agent/utils"
commonutils "github.com/aws/amazon-ecs-agent/ecs-agent/utils"
"github.com/cihub/seelog"
)

Expand Down Expand Up @@ -204,7 +205,7 @@ func (cfg *Config) Merge(rhs Config) *Config {
leftField.Set(reflect.ValueOf(right.Field(i).Interface()))
}
default:
if utils.ZeroOrNil(leftField.Interface()) {
if commonutils.ZeroOrNil(leftField.Interface()) {
leftField.Set(reflect.ValueOf(right.Field(i).Interface()))
}
}
Expand Down Expand Up @@ -395,7 +396,7 @@ func (cfg *Config) checkMissingAndDepreciated() error {
fatalFields := []string{}
for i := 0; i < cfgElem.NumField(); i++ {
cfgField := cfgElem.Field(i)
if utils.ZeroOrNil(cfgField.Interface()) {
if commonutils.ZeroOrNil(cfgField.Interface()) {
missingTag := cfgStructField.Field(i).Tag.Get("missing")
if len(missingTag) == 0 {
continue
Expand Down Expand Up @@ -429,7 +430,7 @@ func (cfg *Config) complete() bool {
cfgElem := reflect.ValueOf(cfg).Elem()

for i := 0; i < cfgElem.NumField(); i++ {
if utils.ZeroOrNil(cfgElem.Field(i).Interface()) {
if commonutils.ZeroOrNil(cfgElem.Field(i).Interface()) {
return false
}
}
Expand Down Expand Up @@ -464,7 +465,7 @@ func fileConfig() (Config, error) {
}

// Handle any deprecated keys correctly here
if utils.ZeroOrNil(cfg.Cluster) && !utils.ZeroOrNil(cfg.ClusterArn) {
if commonutils.ZeroOrNil(cfg.Cluster) && !commonutils.ZeroOrNil(cfg.ClusterArn) {
cfg.Cluster = cfg.ClusterArn
}
return cfg, nil
Expand Down
9 changes: 5 additions & 4 deletions agent/handlers/task_server_setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@ import (
"github.com/aws/amazon-ecs-agent/agent/config"
"github.com/aws/amazon-ecs-agent/agent/engine/dockerstate"
agentAPITaskProtectionV1 "github.com/aws/amazon-ecs-agent/agent/handlers/agentapi/taskprotection/v1/handlers"
v1 "github.com/aws/amazon-ecs-agent/agent/handlers/v1"
v2 "github.com/aws/amazon-ecs-agent/agent/handlers/v2"
v3 "github.com/aws/amazon-ecs-agent/agent/handlers/v3"
v4 "github.com/aws/amazon-ecs-agent/agent/handlers/v4"
Expand All @@ -32,6 +31,8 @@ import (
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds"
tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1"
tmdsv2 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v2"
"github.com/cihub/seelog"
"github.com/gorilla/mux"
)
Expand Down Expand Up @@ -67,8 +68,8 @@ func taskServerSetup(credentialsManager credentials.Manager,
// to permanently redirect(301) to "/v3/metadata/task" handler
muxRouter.SkipClean(false)

muxRouter.HandleFunc(v1.CredentialsPath,
v1.CredentialsHandler(credentialsManager, auditLogger))
muxRouter.HandleFunc(tmdsv1.CredentialsPath,
tmdsv1.CredentialsHandler(credentialsManager, auditLogger))

v2HandlersSetup(muxRouter, state, ecsClient, statsEngine, cluster, credentialsManager, auditLogger, availabilityZone, containerInstanceArn)

Expand Down Expand Up @@ -97,7 +98,7 @@ func v2HandlersSetup(muxRouter *mux.Router,
auditLogger auditinterface.AuditLogger,
availabilityZone string,
containerInstanceArn string) {
muxRouter.HandleFunc(v2.CredentialsPath, v2.CredentialsHandler(credentialsManager, auditLogger))
muxRouter.HandleFunc(tmdsv2.CredentialsPath, tmdsv2.CredentialsHandler(credentialsManager, auditLogger))
muxRouter.HandleFunc(v2.ContainerMetadataPath, v2.TaskContainerMetadataHandler(state, ecsClient, cluster, availabilityZone, containerInstanceArn, false))
muxRouter.HandleFunc(v2.TaskMetadataPath, v2.TaskContainerMetadataHandler(state, ecsClient, cluster, availabilityZone, containerInstanceArn, false))
muxRouter.HandleFunc(v2.TaskWithTagsMetadataPath, v2.TaskContainerMetadataHandler(state, ecsClient, cluster, availabilityZone, containerInstanceArn, true))
Expand Down
13 changes: 7 additions & 6 deletions agent/handlers/task_server_setup_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@ import (
mock_credentials "github.com/aws/amazon-ecs-agent/ecs-agent/credentials/mocks"
mock_audit "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/mocks"
"github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/utils"
tmdsv1 "github.com/aws/amazon-ecs-agent/ecs-agent/tmds/handlers/v1"
"github.com/aws/aws-sdk-go/aws"
"github.com/docker/docker/api/types"
"github.com/golang/mock/gomock"
Expand Down Expand Up @@ -535,7 +536,7 @@ func TestInvalidPath(t *testing.T) {
// query parameters are not specified for the credentials endpoint.
func TestCredentialsV1RequestWithNoArguments(t *testing.T) {
msg := &utils.ErrorMessage{
Code: v1.ErrNoIDInRequest,
Code: tmdsv1.ErrNoIDInRequest,
Message: "CredentialsV1Request: No ID in the request",
HTTPErrorCode: http.StatusBadRequest,
}
Expand All @@ -546,7 +547,7 @@ func TestCredentialsV1RequestWithNoArguments(t *testing.T) {
// query parameters are not specified for the credentials endpoint.
func TestCredentialsV2RequestWithNoArguments(t *testing.T) {
msg := &utils.ErrorMessage{
Code: v1.ErrNoIDInRequest,
Code: tmdsv1.ErrNoIDInRequest,
Message: "CredentialsV2Request: No ID in the request",
HTTPErrorCode: http.StatusBadRequest,
}
Expand All @@ -557,7 +558,7 @@ func TestCredentialsV2RequestWithNoArguments(t *testing.T) {
// the credentials manager does not contain the credentials id specified in the query.
func TestCredentialsV1RequestWhenCredentialsIdNotFound(t *testing.T) {
expectedErrorMessage := &utils.ErrorMessage{
Code: v1.ErrInvalidIDInRequest,
Code: tmdsv1.ErrInvalidIDInRequest,
Message: fmt.Sprintf("CredentialsV1Request: Credentials not found"),
HTTPErrorCode: http.StatusBadRequest,
}
Expand All @@ -571,7 +572,7 @@ func TestCredentialsV1RequestWhenCredentialsIdNotFound(t *testing.T) {
// the credentials manager does not contain the credentials id specified in the query.
func TestCredentialsV2RequestWhenCredentialsIdNotFound(t *testing.T) {
expectedErrorMessage := &utils.ErrorMessage{
Code: v1.ErrInvalidIDInRequest,
Code: tmdsv1.ErrInvalidIDInRequest,
Message: fmt.Sprintf("CredentialsV2Request: Credentials not found"),
HTTPErrorCode: http.StatusBadRequest,
}
Expand All @@ -585,7 +586,7 @@ func TestCredentialsV2RequestWhenCredentialsIdNotFound(t *testing.T) {
// the credentials manager returns empty credentials.
func TestCredentialsV1RequestWhenCredentialsUninitialized(t *testing.T) {
expectedErrorMessage := &utils.ErrorMessage{
Code: v1.ErrCredentialsUninitialized,
Code: tmdsv1.ErrCredentialsUninitialized,
Message: fmt.Sprintf("CredentialsV1Request: Credentials uninitialized for ID"),
HTTPErrorCode: http.StatusServiceUnavailable,
}
Expand All @@ -599,7 +600,7 @@ func TestCredentialsV1RequestWhenCredentialsUninitialized(t *testing.T) {
// the credentials manager returns empty credentials.
func TestCredentialsV2RequestWhenCredentialsUninitialized(t *testing.T) {
expectedErrorMessage := &utils.ErrorMessage{
Code: v1.ErrCredentialsUninitialized,
Code: tmdsv1.ErrCredentialsUninitialized,
Message: fmt.Sprintf("CredentialsV2Request: Credentials uninitialized for ID"),
HTTPErrorCode: http.StatusServiceUnavailable,
}
Expand Down
19 changes: 13 additions & 6 deletions agent/logger/audit/audit_log_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (
"github.com/aws/amazon-ecs-agent/agent/config"
mock_infologger "github.com/aws/amazon-ecs-agent/agent/logger/audit/mocks"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
auditinterface "github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request"
"github.com/golang/mock/gomock"
"github.com/stretchr/testify/assert"
Expand Down Expand Up @@ -77,7 +78,8 @@ func TestWritingToAuditLog(t *testing.T) {
verifyAuditLogEntryResult(logLine, taskARN, dummyURLPath, t)
})

auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode, GetCredentialsEventType(dummyRoleType))
auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode,
auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

func TestWritingToAuditLogV2(t *testing.T) {
Expand Down Expand Up @@ -108,7 +110,8 @@ func TestWritingToAuditLogV2(t *testing.T) {
verifyAuditLogEntryResult(logLine, taskARN, credentials.V2CredentialsPath, t)
})

auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode, GetCredentialsEventType(dummyRoleType))
auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode,
auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

func TestWritingErrorsToAuditLog(t *testing.T) {
Expand Down Expand Up @@ -139,7 +142,8 @@ func TestWritingErrorsToAuditLog(t *testing.T) {
verifyAuditLogEntryResult(logLine, "-", dummyURLPath, t)
})

auditLogger.Log(request.LogRequest{Request: req, ARN: ""}, dummyResponseCode, GetCredentialsEventType(dummyRoleType))
auditLogger.Log(request.LogRequest{Request: req, ARN: ""}, dummyResponseCode,
auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

func TestWritingToAuditLogWhenDisabled(t *testing.T) {
Expand All @@ -162,7 +166,8 @@ func TestWritingToAuditLogWhenDisabled(t *testing.T) {

mockInfoLogger.EXPECT().Info(gomock.Any()).Times(0)

auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode, GetCredentialsEventType(dummyRoleType))
auditLogger.Log(request.LogRequest{Request: req, ARN: taskARN}, dummyResponseCode,
auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType))
}

func TestConstructCommonAuditLogEntryFields(t *testing.T) {
Expand All @@ -181,7 +186,8 @@ func TestConstructCommonAuditLogEntryFields(t *testing.T) {
}

func TestConstructAuditLogEntryByTypeGetCredentials(t *testing.T) {
result := constructAuditLogEntryByType(GetCredentialsEventType(dummyRoleType), dummyCluster,
result := constructAuditLogEntryByType(
auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType), dummyCluster,
dummyContainerInstanceArn)
verifyConstructAuditLogEntryGetCredentialsResult(result, t)
}
Expand Down Expand Up @@ -209,7 +215,8 @@ func verifyConstructAuditLogEntryGetCredentialsResult(result string, t *testing.
tokens := strings.Split(result, " ")

assert.Equal(t, getCredentialsEntryFieldCount, len(tokens), "Incorrect number of tokens in GetCredentials audit log entry")
assert.Equal(t, GetCredentialsEventType(dummyRoleType), tokens[0], "event type does not match")
assert.Equal(t, auditinterface.GetCredentialsEventTypeFromRoleType(dummyRoleType),
tokens[0], "event type does not match")

auditLogVersion, _ := strconv.Atoi(tokens[1])
assert.Equal(t, getCredentialsAuditLogVersion, auditLogVersion, "version does not match")
Expand Down
21 changes: 3 additions & 18 deletions agent/logger/audit/entry_types.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,12 @@ import (
"time"

"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit"
"github.com/aws/amazon-ecs-agent/ecs-agent/logger/audit/request"
log "github.com/cihub/seelog"
)

const (
getCredentialsEventType = "GetCredentials"
getCredentialsTaskExecutionEventType = "GetCredentialsExecutionRole"
getCredentialsInvalidRoleTypeEventType = "GetCredentialsInvalidRoleType"

// getCredentialsAuditLogVersion is the version of the audit log
// Version '1', the fields are:
// 1. event time
Expand Down Expand Up @@ -56,18 +53,6 @@ type commonAuditLogEntryFields struct {
arn string
}

// GetCredentialsEventType is the type for a GetCredentials request
func GetCredentialsEventType(roleType string) string {
switch roleType {
case credentials.ApplicationRoleType:
return getCredentialsEventType
case credentials.ExecutionRoleType:
return getCredentialsTaskExecutionEventType
default:
return getCredentialsInvalidRoleTypeEventType
}
}

func (c *commonAuditLogEntryFields) string() string {
return fmt.Sprintf("%s %d %s %s %s %s", c.eventTime, c.responseCode, c.srcAddr, c.theURL, c.userAgent, c.arn)
}
Expand Down Expand Up @@ -103,15 +88,15 @@ func constructCommonAuditLogEntryFields(r request.LogRequest, httpResponseCode i

func constructAuditLogEntryByType(eventType string, cluster string, containerInstanceArn string) string {
switch eventType {
case getCredentialsEventType:
case audit.GetCredentialsEventType:
fields := &getCredentialsAuditLogEntryFields{
eventType: eventType,
version: getCredentialsAuditLogVersion,
cluster: populateField(cluster),
containerInstanceArn: populateField(containerInstanceArn),
}
return fields.string()
case getCredentialsTaskExecutionEventType:
case audit.GetCredentialsTaskExecutionEventType:
fields := &getCredentialsAuditLogEntryFields{
eventType: eventType,
version: getCredentialsAuditLogVersion,
Expand Down
25 changes: 2 additions & 23 deletions agent/utils/utils.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@ import (

"github.com/aws/amazon-ecs-agent/agent/ecs_client/model/ecs"
"github.com/aws/amazon-ecs-agent/agent/utils/httpproxy"
commonutils "github.com/aws/amazon-ecs-agent/ecs-agent/utils"
"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/aws/arn"
"github.com/aws/aws-sdk-go/aws/awserr"
Expand All @@ -44,28 +45,6 @@ func DefaultIfBlank(str string, default_value string) string {
return str
}

func ZeroOrNil(obj interface{}) bool {
value := reflect.ValueOf(obj)
if !value.IsValid() {
return true
}
if obj == nil {
return true
}
switch value.Kind() {
case reflect.Slice, reflect.Array, reflect.Map:
return value.Len() == 0
}
zero := reflect.Zero(reflect.TypeOf(obj))
if !value.Type().Comparable() {
return false
}
if obj == zero.Interface() {
return true
}
return false
}

// SlicesDeepEqual checks if slice1 and slice2 are equal, disregarding order.
func SlicesDeepEqual(slice1, slice2 interface{}) bool {
s1 := reflect.ValueOf(slice1)
Expand Down Expand Up @@ -216,7 +195,7 @@ func SearchStrInDir(dir, filePrefix, content string) error {
for _, file := range logfiles {
if strings.HasPrefix(file.Name(), filePrefix) {
desiredFile = file.Name()
if ZeroOrNil(desiredFile) {
if commonutils.ZeroOrNil(desiredFile) {
return fmt.Errorf("File with prefix: %v does not exist", filePrefix)
}

Expand Down
47 changes: 0 additions & 47 deletions agent/utils/utils_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
package utils

import (
"encoding/json"
"errors"
"sort"
"testing"
Expand All @@ -39,52 +38,6 @@ func TestDefaultIfBlank(t *testing.T) {
assert.Equal(t, defaultValue, result)
}

type dummyStruct struct {
// no contents
}

func (d dummyStruct) MarshalJSON([]byte, error) {
json.Marshal(nil)
}

func TestZeroOrNil(t *testing.T) {
type ZeroTest struct {
testInt int
TestStr string
testNilJson dummyStruct
}

var strMap map[string]string

testCases := []struct {
param interface{}
expected bool
name string
}{
{nil, true, "Nil is nil"},
{0, true, "0 is 0"},
{"", true, "\"\" is the string zerovalue"},
{ZeroTest{}, true, "ZeroTest zero-value should be zero"},
{ZeroTest{TestStr: "asdf"}, false, "ZeroTest with a field populated isn't zero"},
{ZeroTest{testNilJson: dummyStruct{}}, true, "nil is nil"},
{1, false, "1 is not 0"},
{[]uint16{1, 2, 3}, false, "[1,2,3] is not zero"},
{[]uint16{}, true, "[] is zero"},
{struct{ uncomparable []uint16 }{uncomparable: []uint16{1, 2, 3}}, false, "Uncomparable structs are never zero"},
{struct{ uncomparable []uint16 }{uncomparable: nil}, false, "Uncomparable structs are never zero"},
{strMap, true, "map[string]string is zero or nil"},
{make(map[string]string), true, "empty map[string]string is zero or nil"},
{map[string]string{"foo": "bar"}, false, "map[string]string{foo:bar} is not zero or nil"},
}

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
assert.Equal(t, tc.expected, ZeroOrNil(tc.param), tc.name)
})
}

}

func TestSlicesDeepEqual(t *testing.T) {
testCases := []struct {
left []string
Expand Down

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading