Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Feature/ecs domainless gmsa #3735

Merged
merged 10 commits into from
Jun 5, 2023
19 changes: 16 additions & 3 deletions agent/acs/handler/refresh_credentials_handler.go
Original file line number Diff line number Diff line change
@@ -13,16 +13,22 @@
package handler

import (
"fmt"

"context"
"fmt"

"github.com/aws/amazon-ecs-agent/agent/engine"
"github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
"github.com/aws/amazon-ecs-agent/ecs-agent/wsclient"
"github.com/aws/aws-sdk-go/aws"
"github.com/cihub/seelog"

"github.com/pkg/errors"
)

var (
// For ease of unit testing
checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials
)

// refreshCredentialsHandler represents the refresh credentials operation for the ACS client
@@ -145,10 +151,11 @@ func (refreshHandler *refreshCredentialsHandler) handleSingleMessage(message *ec
if !validRoleType(roleType) {
seelog.Errorf("Unknown RoleType for task in credentials message, roleType: %s arn: %s, messageId: %s", roleType, taskArn, messageId)
} else {
iamRoleCredentials := credentials.IAMRoleCredentialsFromACS(message.RoleCredentials, roleType)
err = refreshHandler.credentialsManager.SetTaskCredentials(
&(credentials.TaskIAMRoleCredentials{
ARN: taskArn,
IAMRoleCredentials: credentials.IAMRoleCredentialsFromACS(message.RoleCredentials, roleType),
IAMRoleCredentials: iamRoleCredentials,
}))
if err != nil {
seelog.Errorf("Unable to update credentials for task, err: %v messageId: %s", err, messageId)
@@ -160,6 +167,12 @@ func (refreshHandler *refreshCredentialsHandler) handleSingleMessage(message *ec
}
if roleType == credentials.ExecutionRoleType {
task.SetExecutionRoleCredentialsID(aws.StringValue(message.RoleCredentials.CredentialsId))
// Refresh domainless gMSA plugin credentials if needed
err = checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl(iamRoleCredentials, task)
if err != nil {
seelog.Errorf("Unable to SetDomainlessGMSATaskExecutionRoleCredentials for task %s, err: %v messageId: %s", taskArn, err, messageId)
return errors.Wrap(err, "unable to SetDomainlessGMSATaskExecutionRoleCredentials")
}
}
}

45 changes: 45 additions & 0 deletions agent/acs/handler/refresh_credentials_handler_linux.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
//go:build linux
// +build linux

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handler

import (
"github.com/aws/amazon-ecs-agent/agent/api/task"
asmfactory "github.com/aws/amazon-ecs-agent/agent/asm/factory"
s3factory "github.com/aws/amazon-ecs-agent/agent/s3/factory"
ssmfactory "github.com/aws/amazon-ecs-agent/agent/ssm/factory"
"github.com/aws/amazon-ecs-agent/agent/taskresource/credentialspec"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
)

func checkAndSetDomainlessGMSATaskExecutionRoleCredentials(iamRoleCredentials credentials.IAMRoleCredentials, task *task.Task) error {
// exit early if the task does not need domainless gMSA
if !task.RequiresDomainlessCredentialSpecResource() {
return nil
}
credspecContainerMapping := task.GetAllCredentialSpecRequirements()
credentialspecResource, err := credentialspec.NewCredentialSpecResource(task.Arn, "", task.ExecutionCredentialsID,
nil, ssmfactory.NewSSMClientCreator(), s3factory.NewS3ClientCreator(), asmfactory.NewClientCreator(), credspecContainerMapping)
if err != nil {
return err
}

err = credentialspecResource.HandleDomainlessKerberosTicketRenewal(iamRoleCredentials)
if err != nil {
return err
}
return nil
}
162 changes: 130 additions & 32 deletions agent/acs/handler/refresh_credentials_handler_test.go
Original file line number Diff line number Diff line change
@@ -17,11 +17,13 @@ package handler

import (
"context"
"fmt"
"reflect"
"sync"
"testing"
"time"

apicontainer "github.com/aws/amazon-ecs-agent/agent/api/container"
apitask "github.com/aws/amazon-ecs-agent/agent/api/task"
mock_engine "github.com/aws/amazon-ecs-agent/agent/engine/mocks"
"github.com/aws/amazon-ecs-agent/ecs-agent/acs/model/ecsacs"
@@ -30,6 +32,7 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/golang/mock/gomock"
"github.com/pkg/errors"
"github.com/stretchr/testify/assert"
)

@@ -234,49 +237,144 @@ func TestCredentialsMessageNotAckedWhenTaskNotFound(t *testing.T) {
}

// TestHandleRefreshMessageAckedWhenCredentialsUpdated tests that a credential message
// is ackd when the credentials are updated successfully
// is ackd when the credentials are updated successfully and the domainless gMSA plugin credentials are updated successfully
func TestHandleRefreshMessageAckedWhenCredentialsUpdated(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
credentialsManager := credentials.NewManager()
testCases := []struct {
name string
taskArn string
domainlessGMSATaskExpectedInput bool
containers []*apicontainer.Container
}{
{
name: "EmptyTaskSucceeds",
taskArn: taskArn,
containers: []*apicontainer.Container{},
},
}

ctx, cancel := context.WithCancel(context.Background())
var ackRequested *ecsacs.IAMRoleCredentialsAckRequest
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
credentialsManager := credentials.NewManager()

mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
ackRequested = ackRequest
cancel()
}).Times(1)
ctx, cancel := context.WithCancel(context.Background())
var ackRequested *ecsacs.IAMRoleCredentialsAckRequest

taskEngine := mock_engine.NewMockTaskEngine(ctrl)
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn(taskArn).Return(&apitask.Task{}, true)
mockWsClient := mock_wsclient.NewMockClientServer(ctrl)
mockWsClient.EXPECT().MakeRequest(gomock.Any()).Do(func(ackRequest *ecsacs.IAMRoleCredentialsAckRequest) {
ackRequested = ackRequest
cancel()
}).Times(1)

handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
go handler.sendAcks()
taskEngine := mock_engine.NewMockTaskEngine(ctrl)
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn(tc.taskArn).Return(&apitask.Task{Arn: tc.taskArn, Containers: tc.containers}, true)

// test adding a credentials message without the MessageId field
err := handler.handleSingleMessage(message)
if err != nil {
t.Errorf("Error updating credentials: %v", err)
}
checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = func(iamRoleCredentials credentials.IAMRoleCredentials, task *apitask.Task) error {
if tc.taskArn != task.Arn {
return errors.New(fmt.Sprintf("Expected taskArnInput to be %s, instead got %s", tc.taskArn, task.Arn))
}

// Wait till we get an ack from the ackBuffer
select {
case <-ctx.Done():
}
return nil
}

if !reflect.DeepEqual(ackRequested, expectedAck) {
t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested)
defer func() {
checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials
}()

handler := newRefreshCredentialsHandler(ctx, clusterName, containerInstanceArn, mockWsClient, credentialsManager, taskEngine)
go handler.sendAcks()

// test adding a credentials message without the MessageId field
err := handler.handleSingleMessage(message)
if err != nil {
t.Errorf("Error updating credentials: %v", err)
}

// Wait till we get an ack from the ackBuffer
select {
case <-ctx.Done():
}

if !reflect.DeepEqual(ackRequested, expectedAck) {
t.Errorf("Message between expected and requested ack. Expected: %v, Requested: %v", expectedAck, ackRequested)
}

creds, exist := credentialsManager.GetTaskCredentials(credentialsId)
if !exist {
t.Errorf("Expected credentials to exist for the task")
}
if !reflect.DeepEqual(creds, expectedCredentials) {
t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds)
}
})
}
}

creds, exist := credentialsManager.GetTaskCredentials(credentialsId)
if !exist {
t.Errorf("Expected credentials to exist for the task")
// TestCredentialsMessageNotAckedWhenDomainlessGMSACredentialsNotSet tests if credential messages
// are not acked when setting the domainlessGMSA Credentials fails
func TestCredentialsMessageNotAckedWhenDomainlessGMSACredentialsError(t *testing.T) {
testCases := []struct {
name string
taskArn string
containers []*apicontainer.Container
domainlessGMSATaskExpectedInput bool
setDomainlessGMSATaskExecutionRoleCredentialsImplError error
expectedErrorString string
}{
{
name: "ErrDomainlessTask",
taskArn: taskArn,
containers: []*apicontainer.Container{{CredentialSpecs: []string{"credentialspecdomainless:file://gmsa_gmsa-acct.json"}}},
domainlessGMSATaskExpectedInput: true,
setDomainlessGMSATaskExecutionRoleCredentialsImplError: errors.New("mock setDomainlessGMSATaskExecutionRoleCredentialsImplError"),
expectedErrorString: "unable to SetDomainlessGMSATaskExecutionRoleCredentials: mock setDomainlessGMSATaskExecutionRoleCredentialsImplError",
},
}
if !reflect.DeepEqual(creds, expectedCredentials) {
t.Errorf("Mismatch between expected credentials and credentials for task. Expected: %v, got: %v", expectedCredentials, creds)

for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
ctrl := gomock.NewController(t)
defer ctrl.Finish()
credentialsManager := credentials.NewManager()

taskEngine := mock_engine.NewMockTaskEngine(ctrl)
// Return a task from the engine for GetTaskByArn
taskEngine.EXPECT().GetTaskByArn(tc.taskArn).Return(&apitask.Task{Arn: tc.taskArn, Containers: tc.containers}, true)

checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = func(iamRoleCredentials credentials.IAMRoleCredentials, task *apitask.Task) error {
if tc.taskArn != task.Arn {
return errors.New(fmt.Sprintf("Expected taskArnInput to be %s, instead got %s", tc.taskArn, task.Arn))
}

return tc.setDomainlessGMSATaskExecutionRoleCredentialsImplError
}

defer func() {
checkAndSetDomainlessGMSATaskExecutionRoleCredentialsImpl = checkAndSetDomainlessGMSATaskExecutionRoleCredentials
}()

ctx, cancel := context.WithCancel(context.Background())
handler := newRefreshCredentialsHandler(ctx, cluster, containerInstance, nil, credentialsManager, taskEngine)

// Start a goroutine to listen for acks. Cancelling the context stops the goroutine
go func() {
for {
select {
// We never expect the message to be acked
case <-handler.ackRequest:
t.Fatalf("Received ack when none expected")
case <-ctx.Done():
return
}
}
}()

err := handler.handleSingleMessage(message)
assert.EqualError(t, err, tc.expectedErrorString)
cancel()
})
}
}

33 changes: 33 additions & 0 deletions agent/acs/handler/refresh_credentials_handler_windows.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
//go:build windows
// +build windows

// Copyright Amazon.com Inc. or its affiliates. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License"). You may
// not use this file except in compliance with the License. A copy of the
// License is located at
//
// http://aws.amazon.com/apache2.0/
//
// or in the "license" file accompanying this file. This file is distributed
// on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
// express or implied. See the License for the specific language governing
// permissions and limitations under the License.

package handler

import (
"github.com/aws/amazon-ecs-agent/agent/api/task"
"github.com/aws/amazon-ecs-agent/agent/taskresource/credentialspec"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
)

// setDomainlessGMSATaskExecutionRoleCredentials sets the taskExecutionRoleCredentials to a Windows Registry Key so that
// the domainless gMSA plugin can use these credentials to retrieve the customer Active Directory credential
func checkAndSetDomainlessGMSATaskExecutionRoleCredentials(iamRoleCredentials credentials.IAMRoleCredentials, task *task.Task) error {
// exit early if the task does not need domainless gMSA
if !task.RequiresDomainlessCredentialSpecResource() {
return nil
}
return credentialspec.SetTaskExecutionCredentialsRegKeys(iamRoleCredentials, task.Arn)
}
Loading