Skip to content

Commit

Permalink
Add support for cross account management of static roles in AWS Secre…
Browse files Browse the repository at this point in the history
…ts (#29645)

* aws-secrets/add-cross-acc-mgmt-static-roles

* refactor

* add function pointer for tests

* delete commented out code

* update

* update comment

* update func name

* add flag

* remove docs
  • Loading branch information
Zlaticanin authored Feb 14, 2025
1 parent 64e92ba commit 6e0c771
Show file tree
Hide file tree
Showing 12 changed files with 177 additions and 50 deletions.
23 changes: 20 additions & 3 deletions builtin/logical/aws/backend.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ import (

"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/aws/aws-sdk-go/service/sts/stsiface"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
"github.com/hashicorp/vault/sdk/queue"
Expand Down Expand Up @@ -87,6 +88,10 @@ func Backend(_ *logical.BackendConfig) *backend {
type backend struct {
*framework.Backend

// Function pointer used to override the IAM client creation for mocked testing
// If set, this function will be called instead of creating real IAM clients
nonCachedClientIAMFunc func(context.Context, logical.Storage, hclog.Logger, *staticRoleEntry) (iamiface.IAMAPI, error)

// Mutex to protect access to reading and writing policies
roleMutex sync.RWMutex

Expand Down Expand Up @@ -131,8 +136,9 @@ func (b *backend) clearClients() {
}

// clientIAM returns the configured IAM client. If nil, it constructs a new one
// and returns it, setting it the internal variable
func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IAMAPI, error) {
// and returns it, setting it the internal variable.
// entry is only needed when configuring the client to use for role assumption.
func (b *backend) clientIAM(ctx context.Context, s logical.Storage, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
b.clientMutex.RLock()
if b.iamClient != nil {
b.clientMutex.RUnlock()
Expand All @@ -150,10 +156,11 @@ func (b *backend) clientIAM(ctx context.Context, s logical.Storage) (iamiface.IA
return b.iamClient, nil
}

iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger())
iamClient, err := b.nonCachedClientIAM(ctx, s, b.Logger(), entry)
if err != nil {
return nil, err
}

b.iamClient = iamClient

return b.iamClient, nil
Expand Down Expand Up @@ -248,3 +255,13 @@ func (b *backend) initialize(ctx context.Context, request *logical.Initializatio
}
return nil
}

// getNonCachedIAMClient returns an IAM client. In a test env, if a mocked client creation
// function is set (nonCachedClientIAMFunc), it will be used instead of the default client creation function.
// This allows us to mock AWS clients in tests.
func (b *backend) getNonCachedIAMClient(ctx context.Context, storage logical.Storage, cfg staticRoleEntry) (iamiface.IAMAPI, error) {
if b.nonCachedClientIAMFunc != nil {
return b.nonCachedClientIAMFunc(ctx, storage, b.Logger(), &cfg)
}
return b.nonCachedClientIAM(ctx, storage, b.Logger(), &cfg)
}
30 changes: 21 additions & 9 deletions builtin/logical/aws/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -148,21 +148,33 @@ func (b *backend) getRootConfigs(ctx context.Context, s logical.Storage, clientT
return configs, nil
}

func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger) (*iam.IAM, error) {
awsConfig, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
if len(awsConfig) != 1 {
return nil, errors.New("could not obtain aws config")
func (b *backend) nonCachedClientIAM(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (*iam.IAM, error) {
var awsConfig *aws.Config
var err error

if entry != nil && entry.AssumeRoleARN != "" {
awsConfig, err = b.assumeRoleStatic(ctx, s, entry)
if err != nil {
return nil, fmt.Errorf("failed to assume role %q: %w", entry.AssumeRoleARN, err)
}
} else {
configs, err := b.getRootConfigs(ctx, s, "iam", logger)
if err != nil {
return nil, err
}
if len(configs) != 1 {
return nil, errors.New("could not obtain aws config")
}
awsConfig = configs[0]
}
sess, err := session.NewSession(awsConfig[0])

sess, err := session.NewSession(awsConfig)
if err != nil {
return nil, err
}
client := iam.New(sess)
if client == nil {
return nil, fmt.Errorf("could not obtain iam client")
return nil, fmt.Errorf("could not obtain IAM client")
}
return client, nil
}
Expand Down
21 changes: 21 additions & 0 deletions builtin/logical/aws/client_ce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

//go:build !enterprise

package aws

import (
"context"
"fmt"

"github.com/aws/aws-sdk-go/aws"
"github.com/hashicorp/vault/sdk/logical"
)

// assumeRoleStatic assumes an AWS role for cross-account static role management.
// It uses the role ARN and session name provided in the staticRoleEntry configuration
// to generate credentials for the assumed role.
func (b *backend) assumeRoleStatic(ctx context.Context, s logical.Storage, entry *staticRoleEntry) (*aws.Config, error) {
return nil, fmt.Errorf("cross-account static roles are only supported in Vault Enterprise")
}
2 changes: 1 addition & 1 deletion builtin/logical/aws/iam_policies.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ func (b *backend) getGroupPolicies(ctx context.Context, s logical.Storage, iamGr
return nil, nil, nil
}

iamClient, err = b.clientIAM(ctx, s)
iamClient, err = b.clientIAM(ctx, s, nil)
if err != nil {
return nil, nil, err
}
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/aws/path_config_rotate_root.go
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ func (b *backend) pathConfigRotateRootUpdate(ctx context.Context, req *logical.R

func (b *backend) rotateRoot(ctx context.Context, req *logical.Request) (*logical.Response, error) {
// have to get the client config first because that takes out a read lock
client, err := b.clientIAM(ctx, req.Storage)
client, err := b.clientIAM(ctx, req.Storage, nil)
if err != nil {
return nil, err
}
Expand Down
50 changes: 27 additions & 23 deletions builtin/logical/aws/path_static_roles.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,16 +21,22 @@ import (
const (
pathStaticRole = "static-roles"

paramRoleName = "name"
paramUsername = "username"
paramRotationPeriod = "rotation_period"
paramRoleName = "name"
paramUsername = "username"
paramRotationPeriod = "rotation_period"
paramAssumeRoleARN = "assume_role_arn"
paramRoleSessionName = "assume_role_session_name"
paramExternalID = "external_id"
)

type staticRoleEntry struct {
Name string `json:"name" structs:"name" mapstructure:"name"`
ID string `json:"id" structs:"id" mapstructure:"id"`
Username string `json:"username" structs:"username" mapstructure:"username"`
RotationPeriod time.Duration `json:"rotation_period" structs:"rotation_period" mapstructure:"rotation_period"`
Name string `json:"name" structs:"name" mapstructure:"name"`
ID string `json:"id" structs:"id" mapstructure:"id"`
Username string `json:"username" structs:"username" mapstructure:"username"`
RotationPeriod time.Duration `json:"rotation_period" structs:"rotation_period" mapstructure:"rotation_period"`
AssumeRoleARN string `json:"assume_role_arn" structs:"assume_role_arn" mapstructure:"assume_role_arn"`
AssumeRoleSessionName string `json:"assume_role_session_name" structs:"assume_role_session_name" mapstructure:"assume_role_session_name"`
ExternalID string `json:"external_id" structs:"external_id" mapstructure:"external_id"`
}

func pathStaticRoles(b *backend) *framework.Path {
Expand All @@ -53,23 +59,12 @@ func pathStaticRoles(b *backend) *framework.Path {
},
}},
}
fields := roleResponse[http.StatusOK][0].Fields
AddStaticAssumeRoleFieldsEnt(fields)

return &framework.Path{
Pattern: fmt.Sprintf("%s/%s", pathStaticRole, framework.GenericNameWithAtRegex(paramRoleName)),
Fields: map[string]*framework.FieldSchema{
paramRoleName: {
Type: framework.TypeString,
Description: descRoleName,
},
paramUsername: {
Type: framework.TypeString,
Description: descUsername,
},
paramRotationPeriod: {
Type: framework.TypeDurationSecond,
Description: descRotationPeriod,
},
},
Fields: fields,

Operations: map[logical.Operation]framework.OperationHandler{
logical.ReadOperation: &framework.PathOperation{
Expand Down Expand Up @@ -159,6 +154,11 @@ func (b *backend) pathStaticRolesWrite(ctx context.Context, req *logical.Request

// other params are optional if we're not Creating

err = validateAssumeRoleFields(data, &config)
if err != nil {
return nil, err
}

if rawUsername, ok := data.GetOk(paramUsername); ok {
config.Username = rawUsername.(string)

Expand Down Expand Up @@ -299,10 +299,11 @@ func (b *backend) validateRoleName(name string) error {
// validateIAMUser checks the user information we have for the role against the information on AWS. On a create, it uses the username
// to retrieve the user information and _sets_ the userID. On update, it validates the userID and username.
func (b *backend) validateIAMUserExists(ctx context.Context, storage logical.Storage, entry *staticRoleEntry, isCreate bool) error {
c, err := b.clientIAM(ctx, storage)
c, err := b.getNonCachedIAMClient(ctx, storage, *entry)
if err != nil {
return fmt.Errorf("unable to validate username %q: %w", entry.Username, err)
return fmt.Errorf("unable to get client to validate username %q: %w", entry.Username, err)
}
b.iamClient = c

// we don't really care about the content of the result, just that it's not an error
out, err := c.GetUser(&iam.GetUserInput{
Expand Down Expand Up @@ -364,4 +365,7 @@ const (
descUsername = "The IAM user to adopt as a static role."
descRotationPeriod = `Period by which to rotate the backing credential of the adopted user.
This can be a Go duration (e.g, '1m', 24h'), or an integer number of seconds.`
descAssumeRoleARN = `The AWS ARN for the role to be assumed when interacting with the account specified.`
descRoleSessionName = `An identifier for the assumed role session.`
descExternalID = `An external ID to be passed to the assumed role session.`
)
29 changes: 29 additions & 0 deletions builtin/logical/aws/path_static_roles_ce.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
// Copyright (c) HashiCorp, Inc.
// SPDX-License-Identifier: BUSL-1.1

//go:build !enterprise

package aws

import (
"fmt"

"github.com/hashicorp/vault/sdk/framework"
)

// AddStaticAssumeRoleFieldsEnt is a no-op for community edition
func AddStaticAssumeRoleFieldsEnt(fields map[string]*framework.FieldSchema) {
// no-op
}

func validateAssumeRoleFields(data *framework.FieldData, config *staticRoleEntry) error {
_, hasAssumeRoleARN := data.GetOk(paramAssumeRoleARN)
_, hasRoleSessionName := data.GetOk(paramRoleSessionName)
_, hasExternalID := data.GetOk(paramExternalID)

if hasAssumeRoleARN || hasRoleSessionName || hasExternalID {
return fmt.Errorf("cross-account static roles are only supported in Vault Enterprise")
}

return nil
}
17 changes: 14 additions & 3 deletions builtin/logical/aws/path_static_roles_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@ import (

"github.com/aws/aws-sdk-go/aws"
"github.com/aws/aws-sdk-go/service/iam"
"github.com/aws/aws-sdk-go/service/iam/iamiface"
"github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-secure-stdlib/awsutil"
"github.com/hashicorp/vault/sdk/framework"
"github.com/hashicorp/vault/sdk/logical"
Expand Down Expand Up @@ -97,7 +99,10 @@ func TestStaticRolesValidation(t *testing.T) {
if err != nil {
t.Fatal(err)
}
b.iamClient = miam
// Used to override the real IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
if err := b.Setup(bgCTX, config); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -241,7 +246,10 @@ func TestStaticRolesWrite(t *testing.T) {
}

b := Backend(config)
b.iamClient = miam
// Used to override the real IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}
if err := b.Setup(bgCTX, config); err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -454,7 +462,10 @@ func TestStaticRoleDelete(t *testing.T) {
}

b := Backend(config)
b.iamClient = miam
// Used to override the real IAM client creation to return the mocked client
b.nonCachedClientIAMFunc = func(ctx context.Context, s logical.Storage, logger hclog.Logger, entry *staticRoleEntry) (iamiface.IAMAPI, error) {
return miam, nil
}

// put in storage
staticRole := staticRoleEntry{
Expand Down
2 changes: 1 addition & 1 deletion builtin/logical/aws/path_user.go
Original file line number Diff line number Diff line change
Expand Up @@ -175,7 +175,7 @@ func (b *backend) pathUserRollback(ctx context.Context, req *logical.Request, _k
username := entry.UserName

// Get the client
client, err := b.clientIAM(ctx, req.Storage)
client, err := b.clientIAM(ctx, req.Storage, nil)
if err != nil {
return err
}
Expand Down
13 changes: 10 additions & 3 deletions builtin/logical/aws/rotation.go
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)
return false, nil
}

b.Logger().Debug("rotating credential", "role", item.Key)
cfg := item.Value.(staticRoleEntry)

creds, err := b.createCredential(ctx, storage, cfg, true)
Expand Down Expand Up @@ -86,9 +87,10 @@ func (b *backend) rotateCredential(ctx context.Context, storage logical.Storage)

// createCredential will create a new iam credential, deleting the oldest one if necessary.
func (b *backend) createCredential(ctx context.Context, storage logical.Storage, cfg staticRoleEntry, shouldLockStorage bool) (*awsCredentials, error) {
iamClient, err := b.clientIAM(ctx, storage)
// Always create a fresh client
iamClient, err := b.getNonCachedIAMClient(ctx, storage, cfg)
if err != nil {
return nil, fmt.Errorf("unable to get the AWS IAM client: %w", err)
return nil, fmt.Errorf("failed to get IAM client for role %q: %w", cfg.Name, err)
}

// IAM users can have a most 2 sets of keys at a time.
Expand Down Expand Up @@ -190,8 +192,13 @@ func (b *backend) deleteCredential(ctx context.Context, storage logical.Storage,
return fmt.Errorf("couldn't delete from storage: %w", err)
}

iamClient, err := b.nonCachedClientIAM(ctx, storage, b.Logger(), &cfg)
if err != nil {
return fmt.Errorf("failed to get IAM client for role %q while deleting: %w", cfg.Name, err)
}

// because we have the information, this is the one we created, so it's safe for us to delete.
_, err = b.iamClient.DeleteAccessKey(&iam.DeleteAccessKeyInput{
_, err = iamClient.DeleteAccessKey(&iam.DeleteAccessKeyInput{
AccessKeyId: aws.String(creds.AccessKeyID),
UserName: aws.String(cfg.Username),
})
Expand Down
Loading

0 comments on commit 6e0c771

Please sign in to comment.