Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refresh handler gmsa linux #3709

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions agent/acs/handler/refresh_credentials_handler_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ package handler

import (
"github.com/aws/amazon-ecs-agent/agent/api/task"
asmfactory "github.com/aws/amazon-ecs-agent/agent/asm/factory"
s3factory "github.com/aws/amazon-ecs-agent/agent/s3/factory"
ssmfactory "github.com/aws/amazon-ecs-agent/agent/ssm/factory"
"github.com/aws/amazon-ecs-agent/agent/taskresource/credentialspec"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
)

Expand All @@ -26,5 +30,16 @@ func checkAndSetDomainlessGMSATaskExecutionRoleCredentials(iamRoleCredentials cr
if !task.RequiresDomainlessCredentialSpecResource() {
return nil
}
credspecContainerMapping := task.GetAllCredentialSpecRequirements()
credentialspecResource, err := credentialspec.NewCredentialSpecResource(task.Arn, "", task.ExecutionCredentialsID,
amogh09 marked this conversation as resolved.
Show resolved Hide resolved
nil, ssmfactory.NewSSMClientCreator(), s3factory.NewS3ClientCreator(), asmfactory.NewClientCreator(), credspecContainerMapping)
if err != nil {
return err
}

err = credentialspecResource.HandleDomainlessKerberosTicketRenewal(iamRoleCredentials)
if err != nil {
return err
}
return nil
}
4 changes: 2 additions & 2 deletions agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ func (task *Task) PostUnmarshalTask(cfg *config.Config,
// initializeCredentialSpecResource builds the resource dependency map for the credentialspec resource
func (task *Task) initializeCredentialSpecResource(config *config.Config, credentialsManager credentials.Manager,
resourceFields *taskresource.ResourceFields) error {
credspecContainerMapping := task.getAllCredentialSpecRequirements()
credspecContainerMapping := task.GetAllCredentialSpecRequirements()
credentialspecResource, err := credentialspec.NewCredentialSpecResource(task.Arn, config.AWSRegion, task.ExecutionCredentialsID,
credentialsManager, resourceFields.SSMClientCreator, resourceFields.S3ClientCreator, resourceFields.ASMClientCreator, credspecContainerMapping)
if err != nil {
Expand Down Expand Up @@ -3019,7 +3019,7 @@ func (task *Task) GetCredentialSpecResource() ([]taskresource.TaskResource, bool
}

// getAllCredentialSpecRequirements is used to build all the credential spec requirements for the task
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

super nit: comment also needs to be update to GetAllCredentialSpecRequirements.

func (task *Task) getAllCredentialSpecRequirements() map[string]string {
func (task *Task) GetAllCredentialSpecRequirements() map[string]string {
reqsContainerMap := make(map[string]string)
for _, container := range task.Containers {
credentialSpec, err := container.GetCredentialSpec()
Expand Down
6 changes: 3 additions & 3 deletions agent/api/task/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4737,7 +4737,7 @@ func TestGetAllCredentialSpecRequirements(t *testing.T) {
Containers: []*apicontainer.Container{container},
}

credentialSpecContainerMap := task.getAllCredentialSpecRequirements()
credentialSpecContainerMap := task.GetAllCredentialSpecRequirements()

credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json"
expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp1"}
Expand All @@ -4758,7 +4758,7 @@ func TestGetAllCredentialSpecRequirementsWithMultipleContainersUsingSameSpec(t *
Containers: []*apicontainer.Container{c1, c2},
}

credentialSpecContainerMap := task.getAllCredentialSpecRequirements()
credentialSpecContainerMap := task.GetAllCredentialSpecRequirements()

credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json"
expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp2"}
Expand All @@ -4785,7 +4785,7 @@ func TestGetAllCredentialSpecRequirementsWithMultipleContainers(t *testing.T) {
Containers: []*apicontainer.Container{c1, c2, c3},
}

credentialSpecContainerMap := task.getAllCredentialSpecRequirements()
credentialSpecContainerMap := task.GetAllCredentialSpecRequirements()

credentialspec1 := "credentialspec:file://gmsa_gmsa-acct-1.json"
credentialspec2 := "credentialspec:file://gmsa_gmsa-acct-2.json"
Expand Down
162 changes: 118 additions & 44 deletions agent/taskresource/credentialspec/credentialspec_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ type CredentialSpecResource struct {
*CredentialSpecResourceCommon
// This stores the identifier associated with the kerberos tickets created for the task
leaseID string
// identify domainless or domain-joined gMSA
isDomainlessGmsa bool
// This stores credspec arn and the corresponding service account name, domain name
// * key := credentialspec:ssmARN, value := corresponding ServiceAccountInfo
// * key := credentialspec:asmARN, value := corresponding ServiceAccountInfo
Expand Down Expand Up @@ -156,13 +158,51 @@ func (cs *CredentialSpecResource) Create() error {
iamCredentials = executionCredentials.GetIAMRoleCredentials()
}

isDomainlessGmsa := false
err := cs.retrieveCredentialSpecs(iamCredentials)
saikiranakula-amzn marked this conversation as resolved.
Show resolved Hide resolved
if err != nil {
return err
}

seelog.Infof("credentials fetcher daemon request: %v", cs.credentialsFetcherRequest)

// Check if skip credential fetcher invocation check override is present
skipSkipCredentialsFetcherInvocationCheck := utils.ParseBool(os.Getenv(envSkipCredentialsFetcherInvocation), false)
if skipSkipCredentialsFetcherInvocationCheck {
seelog.Info("Skipping credential fetcher invocation based on environment override")
testKrbFilePath := "/tmp/tgt"
os.Create(testKrbFilePath)
// assign temporary variable for test
cs.leaseID = "12345"
for k := range cs.ServiceAccountInfoMap {
cs.CredSpecMap[k] = testKrbFilePath
}
return nil
}

if cs.isDomainlessGmsa {
err := cs.handleDomainlessKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
} else {
err := cs.handleKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
}
return nil
}

func (cs *CredentialSpecResource) retrieveCredentialSpecs(iamCredentials credentials.IAMRoleCredentials) error {
cs.isDomainlessGmsa = false
var wg sync.WaitGroup
errorEvents := make(chan error, len(cs.credentialSpecContainerMap))
for credSpecStr := range cs.credentialSpecContainerMap {
isDomainlessGmsa = strings.Contains(credSpecStr, "credentialspecdomainless")
cs.isDomainlessGmsa = strings.Contains(credSpecStr, "credentialspecdomainless")
var credSpecSplit []string
if isDomainlessGmsa {
if cs.isDomainlessGmsa {
credSpecSplit = strings.SplitAfterN(credSpecStr, "credentialspecdomainless:", 2)
} else {
credSpecSplit = strings.SplitAfterN(credSpecStr, "credentialspec:", 2)
Expand All @@ -176,7 +216,7 @@ func (cs *CredentialSpecResource) Create() error {
credSpecValue := credSpecSplit[1]
if strings.HasPrefix(credSpecValue, "file://") {
wg.Add(1)
go cs.handleCredentialspecFile(credSpecStr, credSpecValue, isDomainlessGmsa, &wg, errorEvents)
go cs.handleCredentialspecFile(credSpecStr, credSpecValue, &wg, errorEvents)
continue
}

Expand All @@ -189,10 +229,10 @@ func (cs *CredentialSpecResource) Create() error {
switch parsedARNService {
case "s3":
wg.Add(1)
go cs.handleS3CredentialspecFile(credSpecStr, credSpecValue, isDomainlessGmsa, iamCredentials, &wg, errorEvents)
go cs.handleS3CredentialspecFile(credSpecStr, credSpecValue, iamCredentials, &wg, errorEvents)
case "ssm":
wg.Add(1)
go cs.handleSSMCredentialspecFile(credSpecStr, credSpecValue, isDomainlessGmsa, iamCredentials, &wg, errorEvents)
go cs.handleSSMCredentialspecFile(credSpecStr, credSpecValue, iamCredentials, &wg, errorEvents)
default:
err = errors.New("unsupported credentialspec ARN, only s3/ssm ARNs are valid")
cs.setTerminalReason(err.Error())
Expand All @@ -211,36 +251,6 @@ func (cs *CredentialSpecResource) Create() error {
cs.setTerminalReason(errorString)
return errors.New(errorString)
}

seelog.Infof("credentials fetcher daemon request: %v", cs.credentialsFetcherRequest)

// Check if skip credential fetcher invocation check override is present
skipSkipCredentialsFetcherInvocationCheck := utils.ParseBool(os.Getenv(envSkipCredentialsFetcherInvocation), false)
if skipSkipCredentialsFetcherInvocationCheck {
seelog.Info("Skipping credential fetcher invocation based on environment override")
testKrbFilePath := "/tmp/tgt"
os.Create(testKrbFilePath)
// assign temporary variable for test
cs.leaseID = "12345"
for k := range cs.ServiceAccountInfoMap {
cs.CredSpecMap[k] = testKrbFilePath
}
return nil
}

if isDomainlessGmsa {
err := cs.handleDomainlessKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
} else {
err := cs.handleKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
}
return nil
}

Expand Down Expand Up @@ -286,6 +296,58 @@ func (cs *CredentialSpecResource) handleDomainlessKerberosTicketCreation() error
}
return nil
}
func (cs *CredentialSpecResource) HandleDomainlessKerberosTicketRenewal(iamCredentials credentials.IAMRoleCredentials) error {
//update the region if it is not already set
err := cs.UpdateRegionFromTask()
if err != nil {
return err
}

err = cs.retrieveCredentialSpecs(iamCredentials)
if err != nil {
return err
}

visitedDomainlessUser := make(map[string]bool)
// Renew kerberos tickets for the gMSA service accounts in domain-less mode on the host location /var/credentials-fetcher/krbdir
for _, v := range cs.ServiceAccountInfoMap {
if v.domainlessGmsaUserArn != "" {
_, ok := visitedDomainlessUser[v.domainlessGmsaUserArn]
if !ok {
visitedDomainlessUser[v.domainlessGmsaUserArn] = true

// get domain-user credentials from secrets manager
asmClient := cs.secretsmanagerClientCreator.NewASMClient(cs.region, iamCredentials)

asmSecretData, err := asm.GetSecretFromASM(v.domainlessGmsaUserArn, asmClient)
if err != nil {
return fmt.Errorf("failed to retrieve credentials for domainless gMSA user %s: %w", v.domainlessGmsaUserArn, err)
}
creds := DomainlessUserCredentials{}
if err := json.Unmarshal([]byte(asmSecretData), &creds); err != nil {
return fmt.Errorf("failed to parse asmSecretData for the gMSA AD user: %w", err)
}
//set up server connection to communicate with credentials fetcher daemon
conn, err := credentialsfetcherclient.GetGrpcClientConnection()
if err != nil {
seelog.Errorf("failed to connect with credentials fetcher daemon: %s", err)
return err
}
seelog.Infof("grpc connection: %v", conn)

_, err = credentialsfetcherclient.NewCredentialsFetcherClient(conn, time.Minute).RenewNonDomainJoinedKerberosLease(context.Background(),
creds.Username, creds.Password, creds.DomainName)

if err != nil {
cs.setTerminalReason(err.Error())
return fmt.Errorf("failed to renew kerberos tickets associated service account %s: %w", v.domainlessGmsaUserArn, err)
}
seelog.Infof("renewal is successful: %v", cs.leaseID)
}
}
}
return nil
}

func (cs *CredentialSpecResource) handleKerberosTicketCreation() error {
// Create kerberos tickets for the gMSA service accounts on the host location /var/credentials-fetcher/krbdir
Expand Down Expand Up @@ -323,7 +385,7 @@ func (cs *CredentialSpecResource) handleKerberosTicketCreation() error {
return nil
}

func (cs *CredentialSpecResource) handleCredentialspecFile(originalCredentialSpecFile, credentialSpec string, isDomainlessGmsa bool, wg *sync.WaitGroup, errorEvents chan error) {
func (cs *CredentialSpecResource) handleCredentialspecFile(originalCredentialSpecFile, credentialSpec string, wg *sync.WaitGroup, errorEvents chan error) {
defer wg.Done()

if !strings.HasPrefix(credentialSpec, "file://") {
Expand All @@ -343,15 +405,15 @@ func (cs *CredentialSpecResource) handleCredentialspecFile(originalCredentialSpe

credSpecData := string(data)

err = cs.updateCredSpecMapping(originalCredentialSpecFile, credSpecData, isDomainlessGmsa)
err = cs.updateCredSpecMapping(originalCredentialSpecFile, credSpecData)
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
return
}
}

func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialSpecARN, credentialSpecS3ARN string, isDomainlessGmsa bool, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialSpecARN, credentialSpecS3ARN string, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
defer wg.Done()
if iamCredentials == (credentials.IAMRoleCredentials{}) {
err := errors.New("credentialspec resource: unable to find execution role credentials")
Expand Down Expand Up @@ -387,15 +449,15 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialS
json.Compact(credSpecJsonStringBytes, []byte(credSpecJsonStringUnformatted))
credSpecJsonString := credSpecJsonStringBytes.String()

err = cs.updateCredSpecMapping(originalCredentialSpecARN, credSpecJsonString, isDomainlessGmsa)
err = cs.updateCredSpecMapping(originalCredentialSpecARN, credSpecJsonString)
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
return
}
}

func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredentialSpecARN, credentialSpecSSMARN string, isDomainlessGmsa bool, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredentialSpecARN, credentialSpecSSMARN string, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
defer wg.Done()

if iamCredentials == (credentials.IAMRoleCredentials{}) {
Expand Down Expand Up @@ -433,7 +495,7 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
}

ssmParamData := ssmParamMap[ssmParam[1]]
err = cs.updateCredSpecMapping(originalCredentialSpecARN, ssmParamData, isDomainlessGmsa)
err = cs.updateCredSpecMapping(originalCredentialSpecARN, ssmParamData)
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
Expand All @@ -442,15 +504,15 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
}

// updateCredSpecMapping updates the mapping of credentialSpec input and the corresponding service account info(serviceAccountName, DomainNAme)
func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, credSpecContent string, isDomainlessGmsa bool) error {
func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, credSpecContent string) error {
cs.lock.Lock()
defer cs.lock.Unlock()

var serviceAccountName string
var domainName string
var domainlessGmsaUserArn string
//parse json to extract the service account name and the domain name
if isDomainlessGmsa {
if cs.isDomainlessGmsa {
var credentialSpecDomainlessSchema CredentialSpecDomainlessSchema
// Unmarshal or Decode the JSON to the interface for domainless gmsa.
err := json.Unmarshal([]byte(credSpecContent), &credentialSpecDomainlessSchema)
Expand Down Expand Up @@ -491,6 +553,18 @@ func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, credSpecC
return nil
}

// update region if is not set
func (cs *CredentialSpecResource) UpdateRegionFromTask() error {
// Parse taskARN
parsedARN, err := arn.Parse(cs.taskARN)
if err != nil {
return err
}

cs.region = parsedARN.Region
return nil
}

// Cleanup removes the credentialSpec created for the task
func (cs *CredentialSpecResource) Cleanup() error {
cs.clearKerberosTickets()
Expand Down
Loading