Skip to content

Commit

Permalink
Merge pull request #3709 from saikiranakula-amzn/refresh_handler_gmsa…
Browse files Browse the repository at this point in the history
…_linux

Refresh handler gmsa linux
  • Loading branch information
saikiranakula-amzn authored May 25, 2023
2 parents 4600042 + 11b210d commit 446970c
Show file tree
Hide file tree
Showing 5 changed files with 183 additions and 63 deletions.
15 changes: 15 additions & 0 deletions agent/acs/handler/refresh_credentials_handler_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,10 @@ package handler

import (
"github.com/aws/amazon-ecs-agent/agent/api/task"
asmfactory "github.com/aws/amazon-ecs-agent/agent/asm/factory"
s3factory "github.com/aws/amazon-ecs-agent/agent/s3/factory"
ssmfactory "github.com/aws/amazon-ecs-agent/agent/ssm/factory"
"github.com/aws/amazon-ecs-agent/agent/taskresource/credentialspec"
"github.com/aws/amazon-ecs-agent/ecs-agent/credentials"
)

Expand All @@ -26,5 +30,16 @@ func checkAndSetDomainlessGMSATaskExecutionRoleCredentials(iamRoleCredentials cr
if !task.RequiresDomainlessCredentialSpecResource() {
return nil
}
credspecContainerMapping := task.GetAllCredentialSpecRequirements()
credentialspecResource, err := credentialspec.NewCredentialSpecResource(task.Arn, "", task.ExecutionCredentialsID,
nil, ssmfactory.NewSSMClientCreator(), s3factory.NewS3ClientCreator(), asmfactory.NewClientCreator(), credspecContainerMapping)
if err != nil {
return err
}

err = credentialspecResource.HandleDomainlessKerberosTicketRenewal(iamRoleCredentials)
if err != nil {
return err
}
return nil
}
4 changes: 2 additions & 2 deletions agent/api/task/task.go
Original file line number Diff line number Diff line change
Expand Up @@ -469,7 +469,7 @@ func (task *Task) PostUnmarshalTask(cfg *config.Config,
// initializeCredentialSpecResource builds the resource dependency map for the credentialspec resource
func (task *Task) initializeCredentialSpecResource(config *config.Config, credentialsManager credentials.Manager,
resourceFields *taskresource.ResourceFields) error {
credspecContainerMapping := task.getAllCredentialSpecRequirements()
credspecContainerMapping := task.GetAllCredentialSpecRequirements()
credentialspecResource, err := credentialspec.NewCredentialSpecResource(task.Arn, config.AWSRegion, task.ExecutionCredentialsID,
credentialsManager, resourceFields.SSMClientCreator, resourceFields.S3ClientCreator, resourceFields.ASMClientCreator, credspecContainerMapping)
if err != nil {
Expand Down Expand Up @@ -3019,7 +3019,7 @@ func (task *Task) GetCredentialSpecResource() ([]taskresource.TaskResource, bool
}

// getAllCredentialSpecRequirements is used to build all the credential spec requirements for the task
func (task *Task) getAllCredentialSpecRequirements() map[string]string {
func (task *Task) GetAllCredentialSpecRequirements() map[string]string {
reqsContainerMap := make(map[string]string)
for _, container := range task.Containers {
credentialSpec, err := container.GetCredentialSpec()
Expand Down
6 changes: 3 additions & 3 deletions agent/api/task/task_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4737,7 +4737,7 @@ func TestGetAllCredentialSpecRequirements(t *testing.T) {
Containers: []*apicontainer.Container{container},
}

credentialSpecContainerMap := task.getAllCredentialSpecRequirements()
credentialSpecContainerMap := task.GetAllCredentialSpecRequirements()

credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json"
expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp1"}
Expand All @@ -4758,7 +4758,7 @@ func TestGetAllCredentialSpecRequirementsWithMultipleContainersUsingSameSpec(t *
Containers: []*apicontainer.Container{c1, c2},
}

credentialSpecContainerMap := task.getAllCredentialSpecRequirements()
credentialSpecContainerMap := task.GetAllCredentialSpecRequirements()

credentialspecFileLocation := "credentialspec:file://gmsa_gmsa-acct.json"
expectedCredentialSpecContainerMap := map[string]string{credentialspecFileLocation: "webapp2"}
Expand All @@ -4785,7 +4785,7 @@ func TestGetAllCredentialSpecRequirementsWithMultipleContainers(t *testing.T) {
Containers: []*apicontainer.Container{c1, c2, c3},
}

credentialSpecContainerMap := task.getAllCredentialSpecRequirements()
credentialSpecContainerMap := task.GetAllCredentialSpecRequirements()

credentialspec1 := "credentialspec:file://gmsa_gmsa-acct-1.json"
credentialspec2 := "credentialspec:file://gmsa_gmsa-acct-2.json"
Expand Down
162 changes: 118 additions & 44 deletions agent/taskresource/credentialspec/credentialspec_linux.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,8 @@ type CredentialSpecResource struct {
*CredentialSpecResourceCommon
// This stores the identifier associated with the kerberos tickets created for the task
leaseID string
// identify domainless or domain-joined gMSA
isDomainlessGmsa bool
// This stores credspec arn and the corresponding service account name, domain name
// * key := credentialspec:ssmARN, value := corresponding ServiceAccountInfo
// * key := credentialspec:asmARN, value := corresponding ServiceAccountInfo
Expand Down Expand Up @@ -156,13 +158,51 @@ func (cs *CredentialSpecResource) Create() error {
iamCredentials = executionCredentials.GetIAMRoleCredentials()
}

isDomainlessGmsa := false
err := cs.retrieveCredentialSpecs(iamCredentials)
if err != nil {
return err
}

seelog.Infof("credentials fetcher daemon request: %v", cs.credentialsFetcherRequest)

// Check if skip credential fetcher invocation check override is present
skipSkipCredentialsFetcherInvocationCheck := utils.ParseBool(os.Getenv(envSkipCredentialsFetcherInvocation), false)
if skipSkipCredentialsFetcherInvocationCheck {
seelog.Info("Skipping credential fetcher invocation based on environment override")
testKrbFilePath := "/tmp/tgt"
os.Create(testKrbFilePath)
// assign temporary variable for test
cs.leaseID = "12345"
for k := range cs.ServiceAccountInfoMap {
cs.CredSpecMap[k] = testKrbFilePath
}
return nil
}

if cs.isDomainlessGmsa {
err := cs.handleDomainlessKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
} else {
err := cs.handleKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
}
return nil
}

func (cs *CredentialSpecResource) retrieveCredentialSpecs(iamCredentials credentials.IAMRoleCredentials) error {
cs.isDomainlessGmsa = false
var wg sync.WaitGroup
errorEvents := make(chan error, len(cs.credentialSpecContainerMap))
for credSpecStr := range cs.credentialSpecContainerMap {
isDomainlessGmsa = strings.Contains(credSpecStr, "credentialspecdomainless")
cs.isDomainlessGmsa = strings.Contains(credSpecStr, "credentialspecdomainless")
var credSpecSplit []string
if isDomainlessGmsa {
if cs.isDomainlessGmsa {
credSpecSplit = strings.SplitAfterN(credSpecStr, "credentialspecdomainless:", 2)
} else {
credSpecSplit = strings.SplitAfterN(credSpecStr, "credentialspec:", 2)
Expand All @@ -176,7 +216,7 @@ func (cs *CredentialSpecResource) Create() error {
credSpecValue := credSpecSplit[1]
if strings.HasPrefix(credSpecValue, "file://") {
wg.Add(1)
go cs.handleCredentialspecFile(credSpecStr, credSpecValue, isDomainlessGmsa, &wg, errorEvents)
go cs.handleCredentialspecFile(credSpecStr, credSpecValue, &wg, errorEvents)
continue
}

Expand All @@ -189,10 +229,10 @@ func (cs *CredentialSpecResource) Create() error {
switch parsedARNService {
case "s3":
wg.Add(1)
go cs.handleS3CredentialspecFile(credSpecStr, credSpecValue, isDomainlessGmsa, iamCredentials, &wg, errorEvents)
go cs.handleS3CredentialspecFile(credSpecStr, credSpecValue, iamCredentials, &wg, errorEvents)
case "ssm":
wg.Add(1)
go cs.handleSSMCredentialspecFile(credSpecStr, credSpecValue, isDomainlessGmsa, iamCredentials, &wg, errorEvents)
go cs.handleSSMCredentialspecFile(credSpecStr, credSpecValue, iamCredentials, &wg, errorEvents)
default:
err = errors.New("unsupported credentialspec ARN, only s3/ssm ARNs are valid")
cs.setTerminalReason(err.Error())
Expand All @@ -211,36 +251,6 @@ func (cs *CredentialSpecResource) Create() error {
cs.setTerminalReason(errorString)
return errors.New(errorString)
}

seelog.Infof("credentials fetcher daemon request: %v", cs.credentialsFetcherRequest)

// Check if skip credential fetcher invocation check override is present
skipSkipCredentialsFetcherInvocationCheck := utils.ParseBool(os.Getenv(envSkipCredentialsFetcherInvocation), false)
if skipSkipCredentialsFetcherInvocationCheck {
seelog.Info("Skipping credential fetcher invocation based on environment override")
testKrbFilePath := "/tmp/tgt"
os.Create(testKrbFilePath)
// assign temporary variable for test
cs.leaseID = "12345"
for k := range cs.ServiceAccountInfoMap {
cs.CredSpecMap[k] = testKrbFilePath
}
return nil
}

if isDomainlessGmsa {
err := cs.handleDomainlessKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
} else {
err := cs.handleKerberosTicketCreation()
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
}
return nil
}

Expand Down Expand Up @@ -286,6 +296,58 @@ func (cs *CredentialSpecResource) handleDomainlessKerberosTicketCreation() error
}
return nil
}
func (cs *CredentialSpecResource) HandleDomainlessKerberosTicketRenewal(iamCredentials credentials.IAMRoleCredentials) error {
//update the region if it is not already set
err := cs.UpdateRegionFromTask()
if err != nil {
return err
}

err = cs.retrieveCredentialSpecs(iamCredentials)
if err != nil {
return err
}

visitedDomainlessUser := make(map[string]bool)
// Renew kerberos tickets for the gMSA service accounts in domain-less mode on the host location /var/credentials-fetcher/krbdir
for _, v := range cs.ServiceAccountInfoMap {
if v.domainlessGmsaUserArn != "" {
_, ok := visitedDomainlessUser[v.domainlessGmsaUserArn]
if !ok {
visitedDomainlessUser[v.domainlessGmsaUserArn] = true

// get domain-user credentials from secrets manager
asmClient := cs.secretsmanagerClientCreator.NewASMClient(cs.region, iamCredentials)

asmSecretData, err := asm.GetSecretFromASM(v.domainlessGmsaUserArn, asmClient)
if err != nil {
return fmt.Errorf("failed to retrieve credentials for domainless gMSA user %s: %w", v.domainlessGmsaUserArn, err)
}
creds := DomainlessUserCredentials{}
if err := json.Unmarshal([]byte(asmSecretData), &creds); err != nil {
return fmt.Errorf("failed to parse asmSecretData for the gMSA AD user: %w", err)
}
//set up server connection to communicate with credentials fetcher daemon
conn, err := credentialsfetcherclient.GetGrpcClientConnection()
if err != nil {
seelog.Errorf("failed to connect with credentials fetcher daemon: %s", err)
return err
}
seelog.Infof("grpc connection: %v", conn)

_, err = credentialsfetcherclient.NewCredentialsFetcherClient(conn, time.Minute).RenewNonDomainJoinedKerberosLease(context.Background(),
creds.Username, creds.Password, creds.DomainName)

if err != nil {
cs.setTerminalReason(err.Error())
return fmt.Errorf("failed to renew kerberos tickets associated service account %s: %w", v.domainlessGmsaUserArn, err)
}
seelog.Infof("renewal is successful: %v", cs.leaseID)
}
}
}
return nil
}

func (cs *CredentialSpecResource) handleKerberosTicketCreation() error {
// Create kerberos tickets for the gMSA service accounts on the host location /var/credentials-fetcher/krbdir
Expand Down Expand Up @@ -323,7 +385,7 @@ func (cs *CredentialSpecResource) handleKerberosTicketCreation() error {
return nil
}

func (cs *CredentialSpecResource) handleCredentialspecFile(originalCredentialSpecFile, credentialSpec string, isDomainlessGmsa bool, wg *sync.WaitGroup, errorEvents chan error) {
func (cs *CredentialSpecResource) handleCredentialspecFile(originalCredentialSpecFile, credentialSpec string, wg *sync.WaitGroup, errorEvents chan error) {
defer wg.Done()

if !strings.HasPrefix(credentialSpec, "file://") {
Expand All @@ -343,15 +405,15 @@ func (cs *CredentialSpecResource) handleCredentialspecFile(originalCredentialSpe

credSpecData := string(data)

err = cs.updateCredSpecMapping(originalCredentialSpecFile, credSpecData, isDomainlessGmsa)
err = cs.updateCredSpecMapping(originalCredentialSpecFile, credSpecData)
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
return
}
}

func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialSpecARN, credentialSpecS3ARN string, isDomainlessGmsa bool, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialSpecARN, credentialSpecS3ARN string, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
defer wg.Done()
if iamCredentials == (credentials.IAMRoleCredentials{}) {
err := errors.New("credentialspec resource: unable to find execution role credentials")
Expand Down Expand Up @@ -387,15 +449,15 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentialS
json.Compact(credSpecJsonStringBytes, []byte(credSpecJsonStringUnformatted))
credSpecJsonString := credSpecJsonStringBytes.String()

err = cs.updateCredSpecMapping(originalCredentialSpecARN, credSpecJsonString, isDomainlessGmsa)
err = cs.updateCredSpecMapping(originalCredentialSpecARN, credSpecJsonString)
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
return
}
}

func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredentialSpecARN, credentialSpecSSMARN string, isDomainlessGmsa bool, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredentialSpecARN, credentialSpecSSMARN string, iamCredentials credentials.IAMRoleCredentials, wg *sync.WaitGroup, errorEvents chan error) {
defer wg.Done()

if iamCredentials == (credentials.IAMRoleCredentials{}) {
Expand Down Expand Up @@ -433,7 +495,7 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
}

ssmParamData := ssmParamMap[ssmParam[1]]
err = cs.updateCredSpecMapping(originalCredentialSpecARN, ssmParamData, isDomainlessGmsa)
err = cs.updateCredSpecMapping(originalCredentialSpecARN, ssmParamData)
if err != nil {
cs.setTerminalReason(err.Error())
errorEvents <- err
Expand All @@ -442,15 +504,15 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
}

// updateCredSpecMapping updates the mapping of credentialSpec input and the corresponding service account info(serviceAccountName, DomainNAme)
func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, credSpecContent string, isDomainlessGmsa bool) error {
func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, credSpecContent string) error {
cs.lock.Lock()
defer cs.lock.Unlock()

var serviceAccountName string
var domainName string
var domainlessGmsaUserArn string
//parse json to extract the service account name and the domain name
if isDomainlessGmsa {
if cs.isDomainlessGmsa {
var credentialSpecDomainlessSchema CredentialSpecDomainlessSchema
// Unmarshal or Decode the JSON to the interface for domainless gmsa.
err := json.Unmarshal([]byte(credSpecContent), &credentialSpecDomainlessSchema)
Expand Down Expand Up @@ -491,6 +553,18 @@ func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, credSpecC
return nil
}

// update region if is not set
func (cs *CredentialSpecResource) UpdateRegionFromTask() error {
// Parse taskARN
parsedARN, err := arn.Parse(cs.taskARN)
if err != nil {
return err
}

cs.region = parsedARN.Region
return nil
}

// Cleanup removes the credentialSpec created for the task
func (cs *CredentialSpecResource) Cleanup() error {
cs.clearKerberosTickets()
Expand Down
Loading

0 comments on commit 446970c

Please sign in to comment.