Skip to content

Commit

Permalink
Inject domainless gmsa cred spec into Windows Container
Browse files Browse the repository at this point in the history
  • Loading branch information
arun-annamalai committed May 16, 2023
1 parent 62665fa commit f4687b5
Show file tree
Hide file tree
Showing 2 changed files with 1,195 additions and 272 deletions.
275 changes: 267 additions & 8 deletions agent/taskresource/credentialspec/credentialspec_windows.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ package credentialspec

import (
"crypto/sha256"
"encoding/json"
"fmt"
"os"
"path/filepath"
Expand All @@ -30,11 +31,14 @@ import (
s3factory "github.com/aws/amazon-ecs-agent/agent/s3/factory"
"github.com/aws/amazon-ecs-agent/agent/ssm"
ssmfactory "github.com/aws/amazon-ecs-agent/agent/ssm/factory"
"github.com/aws/amazon-ecs-agent/agent/utils"
"github.com/aws/amazon-ecs-agent/agent/utils/ioutilwrapper"
"github.com/aws/amazon-ecs-agent/agent/utils/oswrapper"
"github.com/aws/aws-sdk-go/aws/arn"

"github.com/cihub/seelog"
"github.com/pkg/errors"
"golang.org/x/sys/windows/registry"
)

const (
Expand All @@ -47,14 +51,38 @@ const (
// Environment variables to setup resource location
envProgramData = "ProgramData"
dockerCredentialSpecDataDir = "docker/credentialspecs"
ecsCcgPluginRegistryKeyRoot = `System\CurrentControlSet\Services\AmazonECSCCGPlugin`
regKeyPathFormat = `HKEY_LOCAL_MACHINE\` + ecsCcgPluginRegistryKeyRoot + `\%s`

credentialSpecParseErrorMsgTemplate = "Unable to parse %s from credential spec"
untypedMarshallErrorMsgTemplate = "Unable to marshal untyped object %s to type %s"
)

var (
// For ease of unit testing
osWriteFileImpl = os.WriteFile
osReadFileImpl = os.ReadFile
osRemoveImpl = os.Remove
readCredentialSpecImpl = readCredentialSpec
writeCredentialSpecImpl = writeCredentialSpec
readWriteDomainlessCredentialSpecImpl = readWriteDomainlessCredentialSpec
setTaskExecutionCredentialsRegKeysImpl = setTaskExecutionCredentialsRegKeys
handleNonFileDomainlessGMSACredSpecImpl = handleNonFileDomainlessGMSACredSpec
deleteTaskExecutionCredentialsRegKeysImpl = deleteTaskExecutionCredentialsRegKeys
)

type pluginInput struct {
CredentialArn string `json:"credentialArn,omitempty"`
RegKeyPath string `json:"regKeyPath,omitempty"`
}

// CredentialSpecResource is the abstraction for credentialspec resources
type CredentialSpecResource struct {
*CredentialSpecResourceCommon
ioutil ioutilwrapper.IOUtil
// credentialSpecResourceLocation is the location for all the tasks' credentialspec artifacts
credentialSpecResourceLocation string
isDomainlessGMSATask bool
}

// NewCredentialSpecResource creates a new CredentialSpecResource object
Expand All @@ -77,7 +105,8 @@ func NewCredentialSpecResource(taskARN, region string,
CredSpecMap: make(map[string]string),
credentialSpecContainerMap: credentialSpecContainerMap,
},
ioutil: ioutilwrapper.NewIOUtil(),
ioutil: ioutilwrapper.NewIOUtil(),
isDomainlessGMSATask: false,
}

err := s.setCredentialSpecResourceLocation()
Expand All @@ -100,11 +129,15 @@ func (cs *CredentialSpecResource) Create() error {
}

for credSpecStr := range cs.credentialSpecContainerMap {
credSpecSplit := strings.SplitAfterN(credSpecStr, "credentialspec:", 2)
credSpecSplit := strings.SplitAfterN(credSpecStr, ":", 2)
if len(credSpecSplit) != 2 {
seelog.Errorf("Invalid credentialspec: %s", credSpecStr)
continue
}
credSpecPrefix := credSpecSplit[0]
if credSpecPrefix == "credentialspecdomainless:" {
cs.isDomainlessGMSATask = true
}
credSpecValue := credSpecSplit[1]

if strings.HasPrefix(credSpecValue, "file://") {
Expand Down Expand Up @@ -145,22 +178,59 @@ func (cs *CredentialSpecResource) Create() error {
}
}

if cs.isDomainlessGMSATask {
// The domainless gMSA Windows Plugin needs the execution role credentials to pull customer secrets
err = setTaskExecutionCredentialsRegKeysImpl(iamCredentials, cs.CredentialSpecResourceCommon.taskARN)
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
}

return nil
}

func (cs *CredentialSpecResource) handleCredentialspecFile(credentialspec string) error {
credSpecSplit := strings.SplitAfterN(credentialspec, "credentialspec:", 2)
credSpecSplit := strings.SplitAfterN(credentialspec, ":", 2)
if len(credSpecSplit) != 2 {
seelog.Errorf("Invalid credentialspec: %s", credentialspec)
return errors.New("invalid credentialspec file specification")
}
credSpecPrefix := credSpecSplit[0]
credSpecFile := credSpecSplit[1]

if !strings.HasPrefix(credSpecFile, "file://") {
return errors.New("invalid credentialspec file specification")
}

dockerHostconfigSecOptCredSpec := strings.Replace(credentialspec, "credentialspec:", "credentialspec=", 1)
if credSpecPrefix == "credentialspecdomainless:" {
relativeFilePath := strings.TrimPrefix(credSpecFile, "file://")
dir, originalFileName := filepath.Split(relativeFilePath)

// Generate unique filename using taskId, containerName, credspecfile original name
taskId, err := utils.TaskIdFromArn(cs.taskARN)
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
containerName, ok := cs.credentialSpecContainerMap[credentialspec]
if !ok {
return errors.New(fmt.Sprintf("Unable to retrieve containerName from credentialSpecContainerMap. No such key %s", credentialspec))
}

// We need a different outfile in order to avoid modifying the customers original credentialspec
outFile := fmt.Sprintf("%s_%s_%s", taskId, containerName, originalFileName)
credSpecFile = "file://" + filepath.Join(dir, outFile)

// Fill in appropriate domainless gMSA fields
err = readWriteDomainlessCredentialSpecImpl(filepath.Join(cs.credentialSpecResourceLocation, dir, originalFileName), filepath.Join(cs.credentialSpecResourceLocation, dir, outFile), cs.taskARN)
if err != nil {
cs.setTerminalReason(err.Error())
return err
}
}

dockerHostconfigSecOptCredSpec := "credentialspec=" + credSpecFile
cs.updateCredSpecMapping(credentialspec, dockerHostconfigSecOptCredSpec)

return nil
Expand Down Expand Up @@ -207,6 +277,12 @@ func (cs *CredentialSpecResource) handleS3CredentialspecFile(originalCredentials
return err
}

err = handleNonFileDomainlessGMSACredSpecImpl(originalCredentialspec, localCredSpecFilePath, cs.taskARN)
if err != nil {
cs.setTerminalReason(err.Error())
return err
}

dockerHostconfigSecOptCredSpec := fmt.Sprintf("credentialspec=file://%s", filepath.Base(localCredSpecFilePath))
cs.updateCredSpecMapping(originalCredentialspec, dockerHostconfigSecOptCredSpec)

Expand Down Expand Up @@ -269,6 +345,13 @@ func (cs *CredentialSpecResource) handleSSMCredentialspecFile(originalCredential
cs.setTerminalReason(err.Error())
return err
}

err = handleNonFileDomainlessGMSACredSpecImpl(originalCredentialspec, localCredSpecFilePath, cs.taskARN)
if err != nil {
cs.setTerminalReason(err.Error())
return err
}

dockerHostconfigSecOptCredSpec := fmt.Sprintf("credentialspec=file://%s", customCredSpecFileName)
cs.updateCredSpecMapping(originalCredentialspec, dockerHostconfigSecOptCredSpec)

Expand Down Expand Up @@ -317,11 +400,15 @@ func (cs *CredentialSpecResource) updateCredSpecMapping(credSpecInput, targetCre
// Cleanup removes the credentialspec created for the task
func (cs *CredentialSpecResource) Cleanup() error {
cs.clearCredentialSpec()
if cs.isDomainlessGMSATask {
err := cs.deleteTaskExecutionCredentialsRegKeys()
if err != nil {
return err
}
}
return nil
}

var remove = os.Remove

// clearCredentialSpec cycles through the collection of credentialspec data and
// removes them from the task
func (cs *CredentialSpecResource) clearCredentialSpec() {
Expand All @@ -334,14 +421,14 @@ func (cs *CredentialSpecResource) clearCredentialSpec() {
continue
}
// Split credentialspec to obtain local file-name
credSpecSplit := strings.SplitAfterN(value, "credentialspec=file://", 2)
credSpecSplit := strings.SplitAfterN(value, "file://", 2)
if len(credSpecSplit) != 2 {
seelog.Warnf("Unable to parse target credentialspec: %s", value)
continue
}
localCredentialSpecFile := credSpecSplit[1]
localCredentialSpecFilePath := filepath.Join(cs.credentialSpecResourceLocation, localCredentialSpecFile)
err := remove(localCredentialSpecFilePath)
err := osRemoveImpl(localCredentialSpecFilePath)
if err != nil {
seelog.Warnf("Unable to clear local credential spec file %s for task %s", localCredentialSpecFile, cs.taskARN)
}
Expand All @@ -350,6 +437,33 @@ func (cs *CredentialSpecResource) clearCredentialSpec() {
}
}

func (cs *CredentialSpecResource) deleteTaskExecutionCredentialsRegKeys() error {
cs.lock.Lock()
defer cs.lock.Unlock()

return deleteTaskExecutionCredentialsRegKeysImpl(cs.taskARN)
}

// deleteTaskExecutionCredentialsRegKeys deletes the taskExecutionRole IAM credentials in the task registry key
// after the task has been terminated.
func deleteTaskExecutionCredentialsRegKeys(taskARN string) error {
k, err := registry.OpenKey(registry.LOCAL_MACHINE, ecsCcgPluginRegistryKeyRoot, registry.ALL_ACCESS)
if err != nil {
// Early exit with success case, if the registry key doesn't exist then there are no task execution role creds to cleanup
seelog.Errorf("Error opening %s key: %s", ecsCcgPluginRegistryKeyRoot, err)
return nil
}
defer k.Close()

err = registry.DeleteKey(k, taskARN)
if err != nil {
seelog.Errorf("Error deleting %s key: %s", ecsCcgPluginRegistryKeyRoot+"\\"+taskARN, err)
return err
}
seelog.Infof("Deleted Task Execution Credential Registry key for task: %s", taskARN)
return nil
}

func (cs *CredentialSpecResource) setCredentialSpecResourceLocation() error {
// TODO: Use registry to setup credentialspec resource location
// This should always be available on Windows instances
Expand Down Expand Up @@ -378,3 +492,148 @@ func (cs *CredentialSpecResource) MarshallPlatformSpecificFields(credentialSpecR
func (cs *CredentialSpecResource) UnmarshallPlatformSpecificFields(credentialSpecResourceJSON CredentialSpecResourceJSON) {
return
}

// setTaskExecutionCredentialsRegKeys stores the taskExecutionRole IAM credentials to the task registry key
// so that the domainless gMSA plugin may use these credentials to access the customer Active Directory authentication
// information.
func setTaskExecutionCredentialsRegKeys(taskCredentials credentials.IAMRoleCredentials, taskArn string) error {
if taskCredentials == (credentials.IAMRoleCredentials{}) {
err := errors.New("Unable to find execution role credentials while setting registry key for task " + taskArn)
return err
}

taskRegistryKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, ecsCcgPluginRegistryKeyRoot+"\\"+taskArn, registry.WRITE)
if err != nil {
errMsg := fmt.Sprintf("Error creating registry key root %s for task %s: %s", ecsCcgPluginRegistryKeyRoot, taskArn, err)
seelog.Errorf(errMsg)
return errors.Wrapf(err, errMsg)
}
defer taskRegistryKey.Close()

err = taskRegistryKey.SetStringValue("AKID", taskCredentials.AccessKeyID)
if err != nil {
errMsg := fmt.Sprintf("Error creating AKID child value for task %s:%s", taskArn, err)
seelog.Errorf(errMsg)
return errors.Wrapf(err, errMsg)
}
err = taskRegistryKey.SetStringValue("SKID", taskCredentials.SecretAccessKey)
if err != nil {
errMsg := fmt.Sprintf("Error creating AKID child value for task %s:%s", taskArn, err)
seelog.Errorf(errMsg)
return errors.Wrapf(err, errMsg)
}
err = taskRegistryKey.SetStringValue("SESSIONTOKEN", taskCredentials.SessionToken)
if err != nil {
errMsg := fmt.Sprintf("Error creating SESSIONTOKEN child value for task %s:%s", taskArn, err)
seelog.Errorf(errMsg)
return errors.Wrapf(err, errMsg)
}

return nil
}

// handleNonFileDomainlessGMSACredSpec reads and then injects the taskExecutionRoleRegistryKey location for
// the s3/ssm gMSA credential spec cases.
func handleNonFileDomainlessGMSACredSpec(originalCredSpec, localCredSpecFilePath, taskARN string) error {
// Exit early for non domainless gMSA cred specs
if !strings.HasPrefix(originalCredSpec, "credentialspecdomainless:") {
return nil
}

err := readWriteDomainlessCredentialSpecImpl(localCredSpecFilePath, localCredSpecFilePath, taskARN)
if err != nil {
return err
}
return nil
}

// readWriteDomainlessCredentialSpec is used to open the credential spec file on local disk, inject the
// taskExecutionRoleInformation in memory, and then write the file to a specific path. The reason we do not
// modify the same fail is to avoid modifying the customer resource when the customer provides a local file
// credential spec
func readWriteDomainlessCredentialSpec(filePath, outFilePath, taskARN string) error {
credSpec, err := readCredentialSpecImpl(filePath)
if err != nil {
return err
}
err = writeCredentialSpecImpl(credSpec, outFilePath, taskARN)
if err != nil {
return err
}
return nil
}

// readCredentialSpec is used to open the credential spec file on local disk and read it into a generic
// bytes map object map[string]interface{}
func readCredentialSpec(filePath string) (map[string]interface{}, error) {
byteResult, err := osReadFileImpl(filePath)
if err != nil {
return nil, err
}
var credSpec map[string]interface{}
err = json.Unmarshal(byteResult, &credSpec)
if err != nil {
return nil, err
}
return credSpec, nil
}

// writeCredentialSpec is used to selectively decode portions of the Microsoft gMSA generated credential spec file and then
// inject the taskExecutionRoleRegistryKey location so that the gMSA plugin is able to access these IAM credentials.
// The reason that the JSON unmarshalling is manual is to protect against future key/value pairs that appear in the JSON,
// while only modifying the portions that pertain to domainless gMSA. This is in case Microsoft adds additional keys to the
// JSON credential spec, so that our writer does not ignore this data.
func writeCredentialSpec(credSpec map[string]interface{}, outFilePath string, taskARN string) error {
activeDirectoryConfigUntyped, ok := credSpec["ActiveDirectoryConfig"]
if !ok {
return errors.New(fmt.Sprintf(credentialSpecParseErrorMsgTemplate, "ActiveDirectoryConfig"))
}
activeDirectoryConfig, ok := activeDirectoryConfigUntyped.(map[string]interface{})
if !ok {
return errors.New(fmt.Sprintf(untypedMarshallErrorMsgTemplate, "activeDirectoryConfigUntyped", "map[string]interface{}"))
}

hostAccountConfigUntyped, ok := activeDirectoryConfig["HostAccountConfig"]
if !ok {
return errors.New(fmt.Sprintf(credentialSpecParseErrorMsgTemplate, "HostAccountConfig"))
}
hostAccountConfig, ok := hostAccountConfigUntyped.(map[string]interface{})
if !ok {
return errors.New(fmt.Sprintf(untypedMarshallErrorMsgTemplate, "hostAccountConfigUntyped", "map[string]interface{}"))
}

pluginInputStringUntyped, ok := hostAccountConfig["PluginInput"]
if !ok {
return errors.New(fmt.Sprintf(credentialSpecParseErrorMsgTemplate, "PluginInput"))
}
var pluginInputParsed pluginInput
pluginInputString, ok := pluginInputStringUntyped.(string)
if !ok {
return errors.New(fmt.Sprintf(untypedMarshallErrorMsgTemplate, "pluginInputStringUntyped", "string"))
}
err := json.Unmarshal([]byte(pluginInputString), &pluginInputParsed)
if err != nil {
return err
}

pluginInputParsed.RegKeyPath = fmt.Sprintf(regKeyPathFormat, taskARN)

pluginInputBytes, err := json.Marshal(pluginInputParsed)
if err != nil {
return err
}

hostAccountConfig["PluginInput"] = string(pluginInputBytes)

jsonBytes, err := json.Marshal(credSpec)
if err != nil {
return err
}

err = osWriteFileImpl(outFilePath, jsonBytes, filePerm)
if err != nil {
return err
}

return nil
}
Loading

0 comments on commit f4687b5

Please sign in to comment.