diff --git a/lib/cloud/aws/policy.go b/lib/cloud/aws/policy.go index 0d4a3ea93c016..ab258da178b80 100644 --- a/lib/cloud/aws/policy.go +++ b/lib/cloud/aws/policy.go @@ -77,7 +77,8 @@ type Statement struct { // Resources is a list of resources. Resources SliceOrString `json:"Resource,omitempty"` // Principals is a list of principals. - Principals map[string]SliceOrString `json:"Principal,omitempty"` + // It can be a single string (eg "*") or a map. + Principals StringOrMap `json:"Principal,omitempty"` // Conditions is a list of conditions that must be satisfied for the action to be allowed. // Example: // Condition: @@ -104,6 +105,52 @@ func (s *Statement) ensureResources(resources []string) { } } +// EqualStatement returns whether the receive statement is the same. +func (s *Statement) EqualStatement(other *Statement) bool { + if s.Effect != other.Effect { + return false + } + + if !slices.Equal(s.Actions, other.Actions) { + return false + } + + if len(s.Principals) != len(other.Principals) { + return false + } + + for principalKind, principalList := range s.Principals { + expectedPrincipalList := other.Principals[principalKind] + if !slices.Equal(principalList, expectedPrincipalList) { + return false + } + } + + if !slices.Equal(s.Resources, other.Resources) { + return false + } + + if len(s.Conditions) != len(other.Conditions) { + return false + } + for conditionKind, conditionOp := range s.Conditions { + expectedConditionOp := other.Conditions[conditionKind] + + if len(conditionOp) != len(expectedConditionOp) { + return false + } + + for conditionOpKind, conditionOpList := range conditionOp { + expectedConditionOpList := expectedConditionOp[conditionOpKind] + if !slices.Equal(conditionOpList, expectedConditionOpList) { + return false + } + } + } + + return true +} + // ParsePolicyDocument returns parsed AWS IAM policy document. func ParsePolicyDocument(document string) (*PolicyDocument, error) { // Policy document returned from AWS API can be URL-encoded: @@ -281,6 +328,62 @@ func (s SliceOrString) MarshalJSON() ([]byte, error) { } } +// StringOrMap defines a type that can be either a single string or a map. +// +// For almost every use case a map is used. Example: +// "Principal": { "Service": ["ecs.amazonaws.com", "elasticloadbalancing.amazonaws.com"]} +// +// For special use cases, like public/anonynous access, a "*" can be used: +// https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html#principal-anonymous +type StringOrMap map[string]SliceOrString + +// UnmarshalJSON implements json.Unmarshaller. +// If it contains a string and not a map, it will create a map with a single entry: +// { "str": [] } +// The only known example is for allowing anything, by using the "*" +// (See examples here // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html#principal-anonymous) +func (s *StringOrMap) UnmarshalJSON(bytes []byte) error { + // Check if input is a map. + var mapInput map[string]SliceOrString + mapErr := json.Unmarshal(bytes, &mapInput) + if mapErr == nil { + *s = mapInput + return nil + } + + // Check if input is a single string. + var str string + strErr := json.Unmarshal(bytes, &str) + if strErr == nil { + *s = StringOrMap{ + str: SliceOrString{}, + } + return nil + } + + // Failed both format. + return trace.NewAggregate(mapErr, strErr) +} + +// MarshalJSON implements json.Marshaler. +// It returns "*" if the map has a single key and that key has 0 items. +// The only known example is for allowing anything, by using the "*" +// (See examples here // https://docs.aws.amazon.com/IAM/latest/UserGuide/reference_policies_elements_principal.html#principal-anonymous) +// The regular Marshal method is used otherwise. +func (s StringOrMap) MarshalJSON() ([]byte, error) { + switch len(s) { + case 0: + return json.Marshal(map[string]SliceOrString{}) + case 1: + if values, isWildcard := s[wildcard]; isWildcard && len(values) == 0 { + return json.Marshal(wildcard) + } + fallthrough + default: + return json.Marshal(map[string]SliceOrString(s)) + } +} + // Policies set of IAM Policy helper functions defined as an interface to make // easier for other packages to mock and test with it. type Policies interface { diff --git a/lib/cloud/aws/policy_statements.go b/lib/cloud/aws/policy_statements.go index 5550d8d6df8e6..c4cd5b3a8a16d 100644 --- a/lib/cloud/aws/policy_statements.go +++ b/lib/cloud/aws/policy_statements.go @@ -25,7 +25,8 @@ import ( "github.com/gravitational/trace" ) -var allResources = []string{"*"} +var wildcard = "*" +var allResources = []string{wildcard} // StatementForIAMEditRolePolicy returns a IAM Policy Statement which allows editting Role Policy // of the resources. @@ -185,6 +186,23 @@ func StatementForListRDSDatabases() *Statement { } } +// StatementForS3BucketPublicRead returns the statement that +// allows public/anonynous access to s3 bucket/prefix objects. +func StatementForS3BucketPublicRead(s3bucketName, objectPrefix string) *Statement { + return &Statement{ + Effect: EffectAllow, + Principals: StringOrMap{ + wildcard: SliceOrString{}, + }, + Actions: []string{ + "s3:GetObject", + }, + Resources: []string{ + fmt.Sprintf("arn:aws:s3:::%s/%s/*", s3bucketName, objectPrefix), + }, + } +} + // ExternalAuditStoragePolicyConfig holds options for the External Audit Storage // IAM policy. type ExternalAuditStoragePolicyConfig struct { diff --git a/lib/cloud/aws/policy_test.go b/lib/cloud/aws/policy_test.go index a485faa70630e..e3c3a61a57cfa 100644 --- a/lib/cloud/aws/policy_test.go +++ b/lib/cloud/aws/policy_test.go @@ -88,6 +88,123 @@ func TestSliceOrString(t *testing.T) { }) } +func TestStringOrMap(t *testing.T) { + t.Run("marshal", func(t *testing.T) { + t.Run("nil input", func(t *testing.T) { + var empty StringOrMap + bytes, err := json.Marshal(empty) + require.NoError(t, err) + require.Equal(t, "{}", string(bytes)) + }) + + t.Run("single entity with single entry", func(t *testing.T) { + in := StringOrMap{"AWS": SliceOrString{"x"}} + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `{"AWS":"x"}`, string(bytes)) + }) + t.Run("single entity with multiple entries", func(t *testing.T) { + in := StringOrMap{"AWS": SliceOrString{"x", "y"}} + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `{"AWS":["x","y"]}`, string(bytes)) + }) + t.Run("multiple entities with multiple entries", func(t *testing.T) { + in := StringOrMap{ + "AWS": SliceOrString{"x", "y"}, + "Principal": SliceOrString{"x", "y"}, + } + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `{"AWS":["x","y"],"Principal":["x","y"]}`, string(bytes)) + }) + t.Run("single entity without entries", func(t *testing.T) { + in := StringOrMap{"AWS": SliceOrString{}} + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `{"AWS":[]}`, string(bytes)) + }) + t.Run("single entity without entries but is wildcard", func(t *testing.T) { + in := StringOrMap{"*": SliceOrString{}} + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `"*"`, string(bytes)) + }) + t.Run("wildcard but at least one entry", func(t *testing.T) { + in := StringOrMap{"*": SliceOrString{"x"}} + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `{"*":"x"}`, string(bytes)) + }) + t.Run("multiple entities but only one of them is wildcard", func(t *testing.T) { + in := StringOrMap{ + "*": SliceOrString{"x"}, + "Principal": SliceOrString{"x"}, + } + bytes, err := json.Marshal(in) + require.NoError(t, err) + require.Equal(t, `{"*":"x","Principal":"x"}`, string(bytes)) + }) + }) + + t.Run("unmarshal", func(t *testing.T) { + t.Run("empty map", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{}, single) + }) + t.Run("single entity with single entry", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{"AWS":"x"}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{"AWS": SliceOrString{"x"}}, single) + }) + t.Run("single entity with multiple entries", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{"AWS":["x","y"]}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{"AWS": SliceOrString{"x", "y"}}, single) + }) + t.Run("multiple entities with multiple entries", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{"AWS":["x","y"],"Principal":["x","y"]}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{ + "AWS": SliceOrString{"x", "y"}, + "Principal": SliceOrString{"x", "y"}, + }, single) + }) + t.Run("single entity without entries", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{"AWS":[]}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{"AWS": SliceOrString{}}, single) + }) + t.Run("single entity without entries but is wildcard", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`"*"`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{"*": SliceOrString{}}, single) + }) + t.Run("wildcard but at least one entry", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{"*":"x"}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{"*": SliceOrString{"x"}}, single) + }) + t.Run("multiple entities but only one of them is wildcard", func(t *testing.T) { + var single StringOrMap + err := json.Unmarshal([]byte(`{"*":"x","Principal":"x"}`), &single) + require.NoError(t, err) + require.Equal(t, StringOrMap{ + "*": SliceOrString{"x"}, + "Principal": SliceOrString{"x"}, + }, single) + }) + }) +} + func TestParsePolicyDocument(t *testing.T) { t.Run("parse without principals", func(t *testing.T) { policyDoc, err := ParsePolicyDocument(`{ @@ -811,3 +928,155 @@ func (m *iamMock) PutRolePermissionsBoundaryWithContext(context.Context, *iam.Pu return &iam.PutRolePermissionsBoundaryOutput{}, nil } + +func TestEqualStatement(t *testing.T) { + for _, tt := range []struct { + name string + statementA *Statement + statementB *Statement + expected bool + }{ + { + name: "empty statement", + statementA: &Statement{}, + statementB: &Statement{}, + expected: true, + }, + { + name: "statement id is ignored", + statementA: &Statement{ + StatementID: "x", + }, + statementB: &Statement{ + StatementID: "y", + }, + expected: true, + }, + { + name: "different number of actions", + statementA: &Statement{ + Actions: SliceOrString{"x", "y"}, + }, + statementB: &Statement{ + Actions: SliceOrString{"y"}, + }, + expected: false, + }, + { + name: "different actions", + statementA: &Statement{ + Actions: SliceOrString{"x"}, + }, + statementB: &Statement{ + Actions: SliceOrString{"y"}, + }, + expected: false, + }, + { + name: "different number of principals", + statementA: &Statement{ + Principals: StringOrMap{"AWS": []string{"123456789012", "123456789013"}}, + }, + statementB: &Statement{ + Principals: StringOrMap{ + "AWS": []string{"123456789012", "123456789014"}, + "OtherPrincipal": []string{"x"}, + }, + }, + expected: false, + }, + { + name: "different principals", + statementA: &Statement{ + Principals: StringOrMap{"AWS": []string{"*"}}, + }, + statementB: &Statement{ + Principals: StringOrMap{"*": []string{}}, + }, + expected: false, + }, + { + name: "different number of conditions", + statementA: &Statement{ + Conditions: map[string]map[string]SliceOrString{ + "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, + "StringLike": {"s3:prefix": []string{"janedoe/*"}}, + }, + }, + statementB: &Statement{ + Conditions: map[string]map[string]SliceOrString{ + "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3601"}}, + }, + }, + expected: false, + }, + { + name: "different conditions", + statementA: &Statement{ + Conditions: map[string]map[string]SliceOrString{ + "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, + }, + }, + statementB: &Statement{ + Conditions: map[string]map[string]SliceOrString{ + "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3601"}}, + }, + }, + expected: false, + }, + { + name: "different condition values", + statementA: &Statement{ + Conditions: map[string]map[string]SliceOrString{ + "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600", "3601"}}, + }, + }, + statementB: &Statement{ + Conditions: map[string]map[string]SliceOrString{ + "NumericLessThanEquals": {"aws:MultiFactorAuthAge": []string{"3600"}}, + }, + }, + expected: false, + }, + { + name: "different resource values", + statementA: &Statement{ + Resources: SliceOrString{"arn:aws:s3:::bucket-2/prefix-2/*"}, + }, + statementB: &Statement{ + Resources: SliceOrString{"arn:aws:s3:::bucket-1/*"}, + }, + expected: false, + }, + { + name: "equal statements", + statementA: &Statement{ + Effect: EffectAllow, + Principals: StringOrMap{ + wildcard: []string{}, + }, + Actions: []string{"s3:GetObject"}, + Resources: []string{"arn:aws:s3:::my-bucket/my-prefix/*"}, + Conditions: map[string]map[string]SliceOrString{ + "StringLike": {"s3:prefix": []string{"my-prefix/*"}}, + }, + }, + statementB: &Statement{ + Effect: EffectAllow, + Principals: StringOrMap{ + wildcard: []string{}, + }, + Actions: []string{"s3:GetObject"}, + Resources: []string{"arn:aws:s3:::my-bucket/my-prefix/*"}, + Conditions: map[string]map[string]SliceOrString{ + "StringLike": {"s3:prefix": []string{"my-prefix/*"}}, + }, + }, + expected: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + require.Equal(t, tt.expected, tt.statementA.EqualStatement(tt.statementB)) + }) + } +} diff --git a/lib/config/configuration.go b/lib/config/configuration.go index 507e943071884..a028e87c842e3 100644 --- a/lib/config/configuration.go +++ b/lib/config/configuration.go @@ -286,6 +286,22 @@ type IntegrationConfAWSOIDCIdP struct { // ProxyPublicURL is the IdP Issuer URL (Teleport Proxy Public Address). // Eg, https://.teleport.sh ProxyPublicURL string + + // S3BucketURI is the S3 URI which contains the bucket name and prefix for the issuer. + // Format: s3:/// + // Eg, s3://my-bucket/idp-teleport + // This is used in two places: + // - create openid configuration and jwks objects + // - set up the issuer + // The bucket must be public and will be created if it doesn't exist. + // + // If empty, the ProxyPublicAddress is used as issuer and no s3 objects are created. + S3BucketURI string + + // S3JWKSContentsB64 must contain the public keys for the Issuer. + // The contents must be Base64 encoded. + // Eg. base64(`{"keys":[{"kty":"RSA","alg":"RS256","n":"","e":"","use":"sig","kid":""}]}`) + S3JWKSContentsB64 string } // IntegrationConfListDatabasesIAM contains the arguments of diff --git a/lib/integrations/awsoidc/idp_iam_config.go b/lib/integrations/awsoidc/idp_iam_config.go index 27e1f3da91079..996688425df89 100644 --- a/lib/integrations/awsoidc/idp_iam_config.go +++ b/lib/integrations/awsoidc/idp_iam_config.go @@ -19,18 +19,29 @@ package awsoidc import ( + "bytes" "context" + "encoding/base64" + "encoding/json" + "fmt" + "net/http" "net/url" + "path" + "strings" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/config" "github.com/aws/aws-sdk-go-v2/service/iam" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sts" "github.com/gravitational/trace" "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" awslib "github.com/gravitational/teleport/lib/cloud/aws" + "github.com/gravitational/teleport/lib/defaults" + awsutil "github.com/gravitational/teleport/lib/utils/aws" + "github.com/gravitational/teleport/lib/utils/oidc" ) const ( @@ -54,14 +65,39 @@ type IdPIAMConfigureRequest struct { // ProxyPublicAddress is the URL to use as provider URL. // This must be a valid URL (ie, url.Parse'able) // Eg, https://.teleport.sh, https://proxy.example.org:443, https://teleport.ec2.aws:3080 + // Only one of ProxyPublicAddress or S3BucketLocation can be used. ProxyPublicAddress string + // S3BucketLocation is the S3 URI which contains the bucket name and prefix for the issuer. + // Format: s3:/// + // Eg, s3://my-bucket/idp-teleport + // This is used in two places: + // - create openid configuration and jwks objects + // - set up the issuer + // The bucket must be public and will be created if it doesn't exist. + // + // If empty, the ProxyPublicAddress is used as issuer and no s3 objects are created. + S3BucketLocation string + + // S3JWKSContentsB64 must contain the public keys for the Issuer. + // The contents must be Base64 encoded. + // Eg. base64(`{"keys":[{"kty":"RSA","alg":"RS256","n":"","e":"","use":"sig","kid":""}]}`) + S3JWKSContentsB64 string + s3Bucket string + s3BucketPrefix string + jwksFileContents []byte + // issuer is the above value but only contains the host. - // Eg, .teleport.sh, proxy.example.org, teleport.ec2.aws:3080 + // Eg, .teleport.sh, proxy.example.org, my-bucket.s3.amazonaws.com/my-prefix issuer string + // issuerURL is the full url for the issuer + // Eg, https://.teleport.sh, https://proxy.example.org, https://my-bucket.s3.amazonaws.com/my-prefix + issuerURL string // IntegrationRole is the Integration's AWS Role used to set up Teleport as an OIDC IdP. IntegrationRole string + + ownershipTags AWSTags } // CheckAndSetDefaults ensures the required fields are present. @@ -78,24 +114,61 @@ func (r *IdPIAMConfigureRequest) CheckAndSetDefaults() error { return trace.BadParameter("integration role is required") } - if r.ProxyPublicAddress == "" { - return trace.BadParameter("proxy public address is required") + if (r.ProxyPublicAddress == "" && r.S3BucketLocation == "") || (r.ProxyPublicAddress != "" && r.S3BucketLocation != "") { + return trace.BadParameter("provide only one of --proxy-public-url or --s3-bucket-uri") } - issuerURL, err := url.Parse(r.ProxyPublicAddress) - if err != nil { - return trace.BadParameter("proxy public address is not a valid url: %v", err) + if r.ProxyPublicAddress != "" { + issuerURL, err := url.Parse(r.ProxyPublicAddress) + if err != nil { + return trace.BadParameter("--proxy-public-url is not a valid url: %v", err) + } + r.issuer = issuerURL.Host + if issuerURL.Port() == "443" { + r.issuer = issuerURL.Hostname() + } + r.issuerURL = issuerURL.String() } - r.issuer = issuerURL.Host - if issuerURL.Port() == "443" { - r.issuer = issuerURL.Hostname() + + if r.S3BucketLocation != "" { + s3BucketURL, err := url.Parse(r.S3BucketLocation) + if err != nil || s3BucketURL.Scheme != "s3" { + return trace.BadParameter("--s3-bucket-uri must be valid s3 uri (eg s3://bucket/prefix)") + } + r.s3Bucket = s3BucketURL.Host + r.s3BucketPrefix = strings.TrimPrefix(s3BucketURL.Path, "/") + + r.issuer = fmt.Sprintf("%s.s3.amazonaws.com/%s", r.s3Bucket, r.s3BucketPrefix) + r.issuerURL = "https://" + r.issuer + + if len(r.S3JWKSContentsB64) == 0 { + return trace.BadParameter("--s3-jwks-base64 is required.") + } + r.jwksFileContents, err = base64.StdEncoding.DecodeString(r.S3JWKSContentsB64) + if err != nil { + return trace.BadParameter("--s3-jwks-base64 is invalid: %v", err) + } } + r.ownershipTags = defaultResourceCreationTags(r.Cluster, r.IntegrationName) + return nil } // IdPIAMConfigureClient describes the required methods to create the AWS OIDC IdP and a Role that trusts that identity provider. +// There is no guarantee that the client is thread safe. type IdPIAMConfigureClient interface { + // SetAWSRegion sets the aws region that must be used. + // This is particularly relevant for API calls that must target a specific region's endpoint. + // Eg calling S3 APIs for buckets that are in another region. + SetAWSRegion(string) + + // RegionForCreateBucket is the AWS Region that should be used to create buckets. + RegionForCreateBucket() string + + // HTTPHead performs an HTTP request for the URL using the HEAD verb. + HTTPHead(ctx context.Context, url string) (resp *http.Response, err error) + // GetCallerIdentity returns information about the caller identity. GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) @@ -104,19 +177,108 @@ type IdPIAMConfigureClient interface { // CreateRole creates a new IAM Role. CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error) + + // GetRole retrieves information about the specified role, including the role's path, + // GUID, ARN, and the role's trust policy that grants permission to assume the + // role. + GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) + + // UpdateAssumeRolePolicy updates the policy that grants an IAM entity permission to assume a role. + // This is typically referred to as the "role trust policy". + UpdateAssumeRolePolicy(ctx context.Context, params *iam.UpdateAssumeRolePolicyInput, optFns ...func(*iam.Options)) (*iam.UpdateAssumeRolePolicyOutput, error) + + // CreateBucket creates an Amazon S3 bucket. + CreateBucket(ctx context.Context, params *s3.CreateBucketInput, optFns ...func(*s3.Options)) (*s3.CreateBucketOutput, error) + + // PutObject adds an object to a bucket. + PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) + + // HeadBucket checks if a bucket exists and if you have permission to access it. + HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error) + + // GetBucketPolicy returns the policy of a specified bucket + GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyOutput, error) + + // PutBucketPolicy applies an Amazon S3 bucket policy to an Amazon S3 bucket. + PutBucketPolicy(ctx context.Context, params *s3.PutBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.PutBucketPolicyOutput, error) + + // DeletePublicAccessBlock removes the PublicAccessBlock configuration for an Amazon S3 bucket. + DeletePublicAccessBlock(ctx context.Context, params *s3.DeletePublicAccessBlockInput, optFns ...func(*s3.Options)) (*s3.DeletePublicAccessBlockOutput, error) } type defaultIdPIAMConfigureClient struct { + httpClient *http.Client + *iam.Client + awsConfig aws.Config stsClient *sts.Client + s3Client *s3.Client } // GetCallerIdentity returns details about the IAM user or role whose credentials are used to call the operation. -func (d defaultIdPIAMConfigureClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { +func (d *defaultIdPIAMConfigureClient) GetCallerIdentity(ctx context.Context, params *sts.GetCallerIdentityInput, optFns ...func(*sts.Options)) (*sts.GetCallerIdentityOutput, error) { return d.stsClient.GetCallerIdentity(ctx, params, optFns...) } +// CreateBucket creates an Amazon S3 bucket. +func (d *defaultIdPIAMConfigureClient) CreateBucket(ctx context.Context, params *s3.CreateBucketInput, optFns ...func(*s3.Options)) (*s3.CreateBucketOutput, error) { + return d.s3Client.CreateBucket(ctx, params, optFns...) +} + +// PutObject adds an object to a bucket. +func (d *defaultIdPIAMConfigureClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + return d.s3Client.PutObject(ctx, params, optFns...) +} + +// HeadBucket adds an object to a bucket. +func (d *defaultIdPIAMConfigureClient) HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error) { + return d.s3Client.HeadBucket(ctx, params, optFns...) +} + +// PutBucketPolicy applies an Amazon S3 bucket policy to an Amazon S3 bucket. +func (d *defaultIdPIAMConfigureClient) PutBucketPolicy(ctx context.Context, params *s3.PutBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.PutBucketPolicyOutput, error) { + return d.s3Client.PutBucketPolicy(ctx, params, optFns...) +} + +// DeletePublicAccessBlock removes the PublicAccessBlock configuration for an Amazon S3 bucket. +func (d *defaultIdPIAMConfigureClient) DeletePublicAccessBlock(ctx context.Context, params *s3.DeletePublicAccessBlockInput, optFns ...func(*s3.Options)) (*s3.DeletePublicAccessBlockOutput, error) { + return d.s3Client.DeletePublicAccessBlock(ctx, params, optFns...) +} + +// GetBucketPolicy returns the policy of a specified bucket +func (d *defaultIdPIAMConfigureClient) GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyOutput, error) { + return d.s3Client.GetBucketPolicy(ctx, params, optFns...) +} + +// RegionForCreateBucket returns the region where the bucket should be created. +func (d *defaultIdPIAMConfigureClient) RegionForCreateBucket() string { + return d.awsConfig.Region +} + +// SetAWSRegion sets the aws region for next api calls. +func (d *defaultIdPIAMConfigureClient) SetAWSRegion(awsRegion string) { + if d.awsConfig.Region == awsRegion { + return + } + + d.awsConfig.Region = awsRegion + + // S3 Client is the only client that depends on the region. + d.s3Client = s3.NewFromConfig(d.awsConfig) +} + +// HTTPHead performs an HTTP request for the URL using the HEAD verb. +func (d *defaultIdPIAMConfigureClient) HTTPHead(ctx context.Context, url string) (*http.Response, error) { + req, err := http.NewRequest(http.MethodHead, url, nil) + if err != nil { + return nil, trace.Wrap(err) + } + + return d.httpClient.Do(req) +} + // NewIdPIAMConfigureClient creates a new IdPIAMConfigureClient. +// The client is not thread safe. func NewIdPIAMConfigureClient(ctx context.Context) (IdPIAMConfigureClient, error) { cfg, err := config.LoadDefaultConfig(ctx) if err != nil { @@ -127,20 +289,38 @@ func NewIdPIAMConfigureClient(ctx context.Context) (IdPIAMConfigureClient, error return nil, trace.BadParameter("failed to resolve local AWS region from environment, please set the AWS_REGION environment variable") } + httpClient, err := defaults.HTTPClient() + if err != nil { + return nil, trace.Wrap(err) + } + return &defaultIdPIAMConfigureClient{ - Client: iam.NewFromConfig(cfg), - stsClient: sts.NewFromConfig(cfg), + httpClient: httpClient, + awsConfig: cfg, + Client: iam.NewFromConfig(cfg), + stsClient: sts.NewFromConfig(cfg), + s3Client: s3.NewFromConfig(cfg), }, nil } // ConfigureIdPIAM creates a new IAM OIDC IdP in AWS. // -// The Provider URL is Teleport's Public Address. +// The Provider URL is Teleport's Public Address or the S3 bucket. // It also creates a new Role configured to trust the recently created IdP. +// If the role already exists, it will create another trust relationship for the IdP (if it doesn't exist). // // The following actions must be allowed by the IAM Role assigned in the Client. // - iam:CreateOpenIDConnectProvider // - iam:CreateRole +// - iam:GetRole +// - iam:UpdateAssumeRolePolicy +// +// If it's using the S3 bucket flow, the following are required as well: +// - s3:CreateBucket +// - s3:GetBucketPolicy +// - s3:PutBucketPolicy +// - s3:DeletePublicAccessBlock +// - s3:PutObject func ConfigureIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { if err := req.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) @@ -154,56 +334,253 @@ func ConfigureIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMC req.AccountID = aws.ToString(callerIdentity.Account) } - thumbprint, err := ThumbprintIdP(ctx, req.ProxyPublicAddress) - if err != nil { + logrus.Infof("Creating IAM OpenID Connect Provider: url=%q.", req.issuerURL) + if err := ensureOIDCIdPIAM(ctx, clt, req); err != nil { return trace.Wrap(err) } - logrus.Infof("Using the following thumbprint: %s", thumbprint) - createOIDCResp, err := clt.CreateOpenIDConnectProvider(ctx, &iam.CreateOpenIDConnectProviderInput{ + logrus.Infof("Creating IAM Role %q.", req.IntegrationRole) + if err := upsertIdPIAMRole(ctx, clt, req); err != nil { + return trace.Wrap(err) + } + + // Configuration stops here if there's no S3 bucket. + // It will use the teleport's public address as IdP issuer. + if req.s3Bucket == "" { + return nil + } + log := logrus.WithFields(logrus.Fields{ + "bucket": req.s3Bucket, + "bucket-prefix": req.s3BucketPrefix, + }) + + log.Infof("Creating bucket in region %q", clt.RegionForCreateBucket()) + if err := ensureBucketIdPIAM(ctx, clt, req, log); err != nil { + return trace.Wrap(err) + } + + log.Info("Setting public access.") + if err := ensureBucketPoliciesIdPIAM(ctx, clt, req); err != nil { + return trace.Wrap(err) + } + + log.Info("Uploading 'openid-configuration' and 'jwks' files.") + if err := uploadOpenIDPublicFiles(ctx, clt, req); err != nil { + return trace.Wrap(err) + } + + return nil +} + +func ensureOIDCIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { + var err error + // For S3 bucket setups the thumbprint is ignored, but the API still requires a parseable one. + // https://github.com/aws-actions/configure-aws-credentials/issues/357#issuecomment-1626357333 + // We pass this dummy one for those scenarios. + thumbprint := "afafafafafafafafafafafafafafafafafafafaf" + + // For set ups that use the ProxyPublicAddress, we still calculate the thumbprint. + if req.ProxyPublicAddress != "" { + thumbprint, err = ThumbprintIdP(ctx, req.ProxyPublicAddress) + if err != nil { + return trace.Wrap(err) + } + } + + _, err = clt.CreateOpenIDConnectProvider(ctx, &iam.CreateOpenIDConnectProviderInput{ ThumbprintList: []string{thumbprint}, - Url: &req.ProxyPublicAddress, + Url: &req.issuerURL, ClientIDList: []string{types.IntegrationAWSOIDCAudience}, - Tags: defaultResourceCreationTags(req.Cluster, req.IntegrationName).ToIAMTags(), + Tags: req.ownershipTags.ToIAMTags(), }) if err != nil { - if trace.IsAlreadyExists(awslib.ConvertIAMv2Error(err)) { - return trace.AlreadyExists("identity provider for the same URL (%s) already exists, please remove it and try again", req.ProxyPublicAddress) + awsErr := awslib.ConvertIAMv2Error(err) + if trace.IsAlreadyExists(awsErr) { + return nil } + return trace.Wrap(err) } - logrus.Infof("IAM OpenID Connect Provider created: url=%q arn=%q.", req.ProxyPublicAddress, aws.ToString(createOIDCResp.OpenIDConnectProviderArn)) - createdIdpIAMRoleArn, err := createIdPIAMRole(ctx, clt, req) + return nil +} + +func ensureBucketIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest, log *logrus.Entry) error { + // According to https://docs.aws.amazon.com/AmazonS3/latest/API/API_GetBucketLocation.html + // s3:GetBucketLocation is not recommended, and should be replaced by s3:HeadBucket according to AWS docs. + // The issue with using s3:HeadBucket is that it returns an error if the SDK client's region is not the same as the bucket. + // Doing a HEAD HTTP request seems to be the best option + resp, err := clt.HTTPHead(ctx, fmt.Sprintf("https://s3.amazonaws.com/%s", req.s3Bucket)) if err != nil { return trace.Wrap(err) } - logrus.Infof("IAM Role created: name=%q arn=%q", req.IntegrationRole, aws.ToString(createdIdpIAMRoleArn)) + defer resp.Body.Close() - return nil + // Even if the bucket is private, the "x-amz-bucket-region" Header will be there. + bucketRegion := resp.Header.Get("x-amz-bucket-region") + if bucketRegion != "" { + if bucketRegion == "EU" { + bucketRegion = "eu-west-1" + } + + clt.SetAWSRegion(bucketRegion) + } + + headBucketResp, err := clt.HeadBucket(ctx, &s3.HeadBucketInput{ + Bucket: &req.s3Bucket, + ExpectedBucketOwner: &req.AccountID, + }) + if err == nil { + log.Infof("Bucket already exists in %q", aws.ToString(headBucketResp.BucketRegion)) + return nil + } + awsErr := awslib.ConvertIAMv2Error(err) + if trace.IsNotFound(awsErr) { + _, err := clt.CreateBucket(ctx, &s3.CreateBucketInput{ + Bucket: &req.s3Bucket, + CreateBucketConfiguration: awsutil.CreateBucketConfiguration(clt.RegionForCreateBucket()), + }) + return trace.Wrap(err) + } + + return trace.Wrap(awsErr) } -func createIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) (*string, error) { +func ensureBucketPoliciesIdPIAM(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { + _, err := clt.DeletePublicAccessBlock(ctx, &s3.DeletePublicAccessBlockInput{ + Bucket: &req.s3Bucket, + ExpectedBucketOwner: &req.AccountID, + }) + if err != nil { + return trace.Wrap(err) + } + + bucketPolicyDoc := awslib.NewPolicyDocument() + bucketPolicyResp, err := clt.GetBucketPolicy(ctx, &s3.GetBucketPolicyInput{ + Bucket: &req.s3Bucket, + ExpectedBucketOwner: &req.AccountID, + }) + if err != nil { + // TODO(marco): this is an S3 error, not an IAM Error + awsErr := awslib.ConvertIAMv2Error(err) + // If no policy is defined yet, it will return a NotFound. + // Any other error, should be returned. + if !trace.IsNotFound(awsErr) { + return trace.Wrap(err) + } + } else { + bucketPolicyDoc, err = awslib.ParsePolicyDocument(aws.ToString(bucketPolicyResp.Policy)) + if err != nil { + return trace.Wrap(err) + } + } + + policyS3PublicRead := awslib.StatementForS3BucketPublicRead(req.s3Bucket, req.s3BucketPrefix) + for _, existingStatement := range bucketPolicyDoc.Statements { + if existingStatement.EqualStatement(policyS3PublicRead) { + return nil + } + } + + bucketPolicyDoc.Statements = append(bucketPolicyDoc.Statements, policyS3PublicRead) + newPolicyDocPublicRead, err := bucketPolicyDoc.Marshal() + if err != nil { + return trace.Wrap(err) + } + + _, err = clt.PutBucketPolicy(ctx, &s3.PutBucketPolicyInput{ + Bucket: &req.s3Bucket, + Policy: &newPolicyDocPublicRead, + }) + + return trace.Wrap(err) +} + +func uploadOpenIDPublicFiles(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { + openidConfigPath := path.Join(req.s3BucketPrefix, ".well-known/openid-configuration") + jwksBucketPath := path.Join(req.s3BucketPrefix, ".well-known/jwks") + jwksPublicURI, err := url.JoinPath(req.issuerURL, ".well-known/jwks") + if err != nil { + return trace.Wrap(err) + } + + openIDConfigJSON, err := json.Marshal(oidc.OpenIDConfigurationForIssuer(req.issuer, jwksPublicURI)) + if err != nil { + return trace.Wrap(err) + } + _, err = clt.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &req.s3Bucket, + Key: &openidConfigPath, + Body: bytes.NewReader(openIDConfigJSON), + }) + if err != nil { + return trace.Wrap(err) + } + + _, err = clt.PutObject(ctx, &s3.PutObjectInput{ + Bucket: &req.s3Bucket, + Key: &jwksBucketPath, + Body: bytes.NewReader(req.jwksFileContents), + }) + return trace.Wrap(err) +} + +func createIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { integrationRoleAssumeRoleDocument, err := awslib.NewPolicyDocument( awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}), ).Marshal() if err != nil { - return nil, trace.Wrap(err) + return trace.Wrap(err) } - createRoleOutput, err := clt.CreateRole(ctx, &iam.CreateRoleInput{ + _, err = clt.CreateRole(ctx, &iam.CreateRoleInput{ RoleName: &req.IntegrationRole, Description: aws.String(descriptionOIDCIdPRole), AssumeRolePolicyDocument: &integrationRoleAssumeRoleDocument, - Tags: defaultResourceCreationTags(req.Cluster, req.IntegrationName).ToIAMTags(), + Tags: req.ownershipTags.ToIAMTags(), + }) + return trace.Wrap(err) +} + +func upsertIdPIAMRole(ctx context.Context, clt IdPIAMConfigureClient, req IdPIAMConfigureRequest) error { + getRoleOut, err := clt.GetRole(ctx, &iam.GetRoleInput{ + RoleName: &req.IntegrationRole, }) if err != nil { convertedErr := awslib.ConvertIAMv2Error(err) - if trace.IsAlreadyExists(convertedErr) { - return nil, trace.AlreadyExists("Role %q already exists, please remove it and try again.", req.IntegrationRole) + if !trace.IsNotFound(convertedErr) { + return trace.Wrap(convertedErr) } - return nil, trace.Wrap(convertedErr) + + return trace.Wrap(createIdPIAMRole(ctx, clt, req)) } - return createRoleOutput.Role.Arn, nil + if !req.ownershipTags.MatchesIAMTags(getRoleOut.Role.Tags) { + return trace.BadParameter("IAM Role %q already exists but is not managed by Teleport. "+ + "Add the following tags to allow Teleport to manage this Role: %s", req.IntegrationRole, req.ownershipTags) + } + + trustRelationshipDoc, err := awslib.ParsePolicyDocument(aws.ToString(getRoleOut.Role.AssumeRolePolicyDocument)) + if err != nil { + return trace.Wrap(err) + } + + trustRelationshipForIdP := awslib.StatementForAWSOIDCRoleTrustRelationship(req.AccountID, req.issuer, []string{types.IntegrationAWSOIDCAudience}) + for _, existingStatement := range trustRelationshipDoc.Statements { + if existingStatement.EqualStatement(trustRelationshipForIdP) { + return nil + } + } + + trustRelationshipDoc.Statements = append(trustRelationshipDoc.Statements, trustRelationshipForIdP) + trustRelationshipDocString, err := trustRelationshipDoc.Marshal() + if err != nil { + return trace.Wrap(err) + } + + _, err = clt.UpdateAssumeRolePolicy(ctx, &iam.UpdateAssumeRolePolicyInput{ + RoleName: &req.IntegrationRole, + PolicyDocument: &trustRelationshipDocString, + }) + return trace.Wrap(err) } diff --git a/lib/integrations/awsoidc/idp_iam_config_test.go b/lib/integrations/awsoidc/idp_iam_config_test.go index 14ade92e26f2b..bfed23765293d 100644 --- a/lib/integrations/awsoidc/idp_iam_config_test.go +++ b/lib/integrations/awsoidc/idp_iam_config_test.go @@ -20,30 +20,48 @@ package awsoidc import ( "context" + "encoding/base64" "fmt" + "net/http" "net/http/httptest" + "net/url" "slices" + "strings" "testing" "github.com/aws/aws-sdk-go-v2/aws" "github.com/aws/aws-sdk-go-v2/service/iam" iamTypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + "github.com/aws/aws-sdk-go-v2/service/s3" "github.com/aws/aws-sdk-go-v2/service/sts" + "github.com/gravitational/trace" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/lib" ) -var baseIdPIAMConfigReq = func() IdPIAMConfigureRequest { - return IdPIAMConfigureRequest{ - Cluster: "mycluster", - IntegrationName: "myintegration", - IntegrationRole: "integrationrole", - ProxyPublicAddress: "https://proxy.example.com", +func TestIdPIAMConfigReqDefaults(t *testing.T) { + base64EncodedString := base64.StdEncoding.EncodeToString([]byte(`jwks`)) + + baseIdPIAMConfigReqWithS3Bucket := func() IdPIAMConfigureRequest { + return IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + S3BucketLocation: "s3://bucket-1/prefix-2", + S3JWKSContentsB64: base64EncodedString, + } + } + + baseIdPIAMConfigReqWithProxy := func() IdPIAMConfigureRequest { + return IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + ProxyPublicAddress: "https://proxy.example.com", + } } -} -func TestIdPIAMConfigReqDefaults(t *testing.T) { for _, tt := range []struct { name string req func() IdPIAMConfigureRequest @@ -51,8 +69,8 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { expected IdPIAMConfigureRequest }{ { - name: "set defaults", - req: baseIdPIAMConfigReq, + name: "proxy mode: set defaults", + req: baseIdPIAMConfigReqWithProxy, errCheck: require.NoError, expected: IdPIAMConfigureRequest{ Cluster: "mycluster", @@ -60,40 +78,113 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { IntegrationRole: "integrationrole", ProxyPublicAddress: "https://proxy.example.com", issuer: "proxy.example.com", + issuerURL: "https://proxy.example.com", + ownershipTags: AWSTags{ + "teleport.dev/cluster": "mycluster", + "teleport.dev/integration": "myintegration", + "teleport.dev/origin": "integration_awsoidc", + }, }, }, { - name: "missing cluster", + name: "proxy mode: missing proxy public address", req: func() IdPIAMConfigureRequest { - req := baseIdPIAMConfigReq() - req.Cluster = "" + req := baseIdPIAMConfigReqWithProxy() + req.ProxyPublicAddress = "" return req }, errCheck: badParameterCheck, }, { - name: "missing integration name", + name: "s3 bucket mode: set defaults", + req: baseIdPIAMConfigReqWithS3Bucket, + errCheck: require.NoError, + expected: IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + S3BucketLocation: "s3://bucket-1/prefix-2", + s3Bucket: "bucket-1", + s3BucketPrefix: "prefix-2", + jwksFileContents: []byte(`jwks`), + S3JWKSContentsB64: base64EncodedString, + issuer: "bucket-1.s3.amazonaws.com/prefix-2", + issuerURL: "https://bucket-1.s3.amazonaws.com/prefix-2", + ownershipTags: AWSTags{ + "teleport.dev/cluster": "mycluster", + "teleport.dev/integration": "myintegration", + "teleport.dev/origin": "integration_awsoidc", + }, + }, + }, + { + name: "s3 bucket mode: missing jwks content", req: func() IdPIAMConfigureRequest { - req := baseIdPIAMConfigReq() - req.IntegrationName = "" + req := baseIdPIAMConfigReqWithS3Bucket() + req.S3JWKSContentsB64 = "" return req }, errCheck: badParameterCheck, }, { - name: "missing integration role", + name: "s3 bucket mode: invalid jwks content", req: func() IdPIAMConfigureRequest { - req := baseIdPIAMConfigReq() - req.IntegrationRole = "" + req := baseIdPIAMConfigReqWithS3Bucket() + req.S3JWKSContentsB64 = "x" return req }, errCheck: badParameterCheck, }, { - name: "missing proxy public address", + name: "s3 bucket mode: invalid url for s3 location", req: func() IdPIAMConfigureRequest { - req := baseIdPIAMConfigReq() - req.ProxyPublicAddress = "" + req := baseIdPIAMConfigReqWithS3Bucket() + req.S3BucketLocation = "invalid-url" + return req + }, + errCheck: badParameterCheck, + }, + { + name: "s3 bucket mode: invalid schema for s3 location", + req: func() IdPIAMConfigureRequest { + req := baseIdPIAMConfigReqWithS3Bucket() + req.S3BucketLocation = "https://proxy.example.com" + return req + }, + errCheck: badParameterCheck, + }, + { + name: "proxy and s3 bucket defined", + req: func() IdPIAMConfigureRequest { + req := baseIdPIAMConfigReqWithProxy() + req.S3BucketLocation = "s3://bucket/prefix" + return req + }, + errCheck: badParameterCheck, + }, + { + name: "missing cluster", + req: func() IdPIAMConfigureRequest { + req := baseIdPIAMConfigReqWithProxy() + req.Cluster = "" + return req + }, + errCheck: badParameterCheck, + }, + { + name: "missing integration name", + req: func() IdPIAMConfigureRequest { + req := baseIdPIAMConfigReqWithProxy() + req.IntegrationName = "" + return req + }, + errCheck: badParameterCheck, + }, + { + name: "missing integration role", + req: func() IdPIAMConfigureRequest { + req := baseIdPIAMConfigReqWithProxy() + req.IntegrationRole = "" return req }, errCheck: badParameterCheck, @@ -112,66 +203,451 @@ func TestIdPIAMConfigReqDefaults(t *testing.T) { } } -func TestConfigureIdPIAM(t *testing.T) { +func policyDocWithStatementsJSON(statement ...string) *string { + statements := strings.Join(statement, ",") + ret := fmt.Sprintf(`{ + "Version": "2012-10-17", + "Statement": [ + %s + ] + }`, statements) + return &ret +} + +func assumeRoleStatementJSON(issuer string) string { + return fmt.Sprintf(`{ + "Effect": "Allow", + "Action": "sts:AssumeRoleWithWebIdentity", + "Principal": { + "Federated": "arn:aws:iam::123456789012:oidc-provider/%s" + }, + "Condition": { + "StringEquals": { + "%s:aud": "discover.teleport" + } + } +}`, issuer, issuer) +} + +func policyStatementS3PublicAccessJSON(bucket, prefix string) string { + return fmt.Sprintf(`{ + "Effect": "Allow", + "Principal": "*", + "Action": "s3:GetObject", + "Resource": "arn:aws:s3:::%s/%s/*" +}`, bucket, prefix) +} + +func TestConfigureIdPIAMUsingProxyURL(t *testing.T) { ctx := context.Background() - tlsServer := httptest.NewTLSServer(nil) - // TLS Server starts with self-signed certificates. - lib.SetInsecureDevMode(true) - defer lib.SetInsecureDevMode(false) + t.Run("using proxy url", func(t *testing.T) { + tlsServer := httptest.NewTLSServer(nil) + tlsServerURL, err := url.Parse(tlsServer.URL) + require.NoError(t, err) - baseIdPIAMConfigReqWithTLServer := func() IdPIAMConfigureRequest { - base := baseIdPIAMConfigReq() - base.ProxyPublicAddress = tlsServer.URL - return base - } + tlsServerIssuer := tlsServerURL.Host + // TLS Server starts with self-signed certificates. - for _, tt := range []struct { - name string - mockAccountID string - mockExistingRoles []string - mockExistingIdPUrl []string - req func() IdPIAMConfigureRequest - errCheck require.ErrorAssertionFunc - }{ - { - name: "valid", - mockAccountID: "123456789012", - req: baseIdPIAMConfigReqWithTLServer, - errCheck: require.NoError, - }, - { - name: "idp url already exists", - mockAccountID: "123456789012", - mockExistingIdPUrl: []string{tlsServer.URL}, - req: baseIdPIAMConfigReqWithTLServer, - errCheck: alreadyExistsCheck, - }, - { - name: "integration role already exists", - mockAccountID: "123456789012", - mockExistingRoles: []string{"integrationrole"}, - req: baseIdPIAMConfigReqWithTLServer, - errCheck: alreadyExistsCheck, - }, - } { - t.Run(tt.name, func(t *testing.T) { - clt := mockIdPIAMConfigClient{ - accountID: tt.mockAccountID, - existingRoles: tt.mockExistingRoles, - existingIDPUrl: tt.mockExistingIdPUrl, + lib.SetInsecureDevMode(true) + defer lib.SetInsecureDevMode(false) + + baseIdPIAMConfigReqWithTLServer := func() IdPIAMConfigureRequest { + return IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + ProxyPublicAddress: tlsServer.URL, } + } - err := ConfigureIdPIAM(ctx, &clt, tt.req()) - tt.errCheck(t, err) - }) - } + for _, tt := range []struct { + name string + mockAccountID string + mockExistingRoles map[string]mockRole + mockExistingIdPUrl []string + req func() IdPIAMConfigureRequest + errCheck require.ErrorAssertionFunc + externalStateCheck func(*testing.T, mockIdPIAMConfigClient) + }{ + { + name: "valid", + mockAccountID: "123456789012", + req: baseIdPIAMConfigReqWithTLServer, + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{}, + errCheck: require.NoError, + }, + { + name: "idp url already exists", + mockAccountID: "123456789012", + mockExistingIdPUrl: []string{tlsServer.URL}, + mockExistingRoles: map[string]mockRole{}, + req: baseIdPIAMConfigReqWithTLServer, + errCheck: require.NoError, + }, + { + name: "role exists, no ownership tags", + mockAccountID: "123456789012", + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{"integrationrole": {}}, + req: baseIdPIAMConfigReqWithTLServer, + errCheck: badParameterCheck, + }, + { + name: "role exists, ownership tags, no assume role", + mockAccountID: "123456789012", + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{"integrationrole": { + tags: []iamTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + }, + assumeRolePolicyDoc: aws.String(`{"Version":"2012-10-17", "Statements":[]}`), + }}, + req: baseIdPIAMConfigReqWithTLServer, + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON(tlsServerIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + }, + }, + { + name: "role exists, ownership tags, with existing assume role", + mockAccountID: "123456789012", + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{"integrationrole": { + tags: []iamTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + }, + assumeRolePolicyDoc: policyDocWithStatementsJSON( + assumeRoleStatementJSON("some-other-issuer"), + ), + }}, + req: baseIdPIAMConfigReqWithTLServer, + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON("some-other-issuer"), + assumeRoleStatementJSON(tlsServerIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + }, + }, + { + name: "role exists, ownership tags, assume role already exists", + mockAccountID: "123456789012", + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{"integrationrole": { + tags: []iamTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + }, + assumeRolePolicyDoc: policyDocWithStatementsJSON( + assumeRoleStatementJSON(tlsServerIssuer), + ), + }}, + req: baseIdPIAMConfigReqWithTLServer, + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON(tlsServerIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + clt := mockIdPIAMConfigClient{ + accountID: tt.mockAccountID, + existingRoles: tt.mockExistingRoles, + existingIDPUrl: tt.mockExistingIdPUrl, + } + + err := ConfigureIdPIAM(ctx, &clt, tt.req()) + tt.errCheck(t, err) + + if tt.externalStateCheck != nil { + tt.externalStateCheck(t, clt) + } + }) + } + }) + + t.Run("using s3 bucket", func(t *testing.T) { + base64EncodedString := base64.StdEncoding.EncodeToString([]byte(`jwks`)) + + baseIdPIAMConfigReqWithS3Bucket := func() IdPIAMConfigureRequest { + return IdPIAMConfigureRequest{ + Cluster: "mycluster", + IntegrationName: "myintegration", + IntegrationRole: "integrationrole", + S3BucketLocation: "s3://bucket-1/prefix-2", + S3JWKSContentsB64: base64EncodedString, + } + } + expectedIssuer := "bucket-1.s3.amazonaws.com/prefix-2" + expectedIssuerURL := "https://" + expectedIssuer + + for _, tt := range []struct { + name string + mockAccountID string + mockExistingIdPUrl []string + mockExistingRoles map[string]mockRole + mockClientRegion string + mockExistingBuckets map[string]mockBucket + req func() IdPIAMConfigureRequest + errCheck require.ErrorAssertionFunc + externalStateCheck func(*testing.T, mockIdPIAMConfigClient) + }{ + { + name: "valid without any existing resources", + mockAccountID: "123456789012", + req: baseIdPIAMConfigReqWithS3Bucket, + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{}, + mockExistingBuckets: map[string]mockBucket{}, + mockClientRegion: "my-region", + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + // Check IdP creation + require.Contains(t, mipc.existingIDPUrl, expectedIssuerURL) + + // Check Role creation + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON(expectedIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + + // Check Bucket creation + require.Contains(t, mipc.existingBuckets, "bucket-1") + bucket := mipc.existingBuckets["bucket-1"] + require.Equal(t, "my-region", bucket.region) + require.False(t, bucket.publicAccessIsBlocked) + expectedBucketPolicyDoc := policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-1", "prefix-2"), + ) + require.JSONEq(t, *expectedBucketPolicyDoc, *bucket.policyDoc) + + }, + }, + { + name: "valid with an existing IdP set up using Proxy URL", + mockAccountID: "123456789012", + req: baseIdPIAMConfigReqWithS3Bucket, + mockExistingIdPUrl: []string{"https://proxy.example.com"}, + mockExistingRoles: map[string]mockRole{ + "integrationrole": { + tags: []iamTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + }, + assumeRolePolicyDoc: policyDocWithStatementsJSON( + assumeRoleStatementJSON("proxy.example.com"), + ), + }, + }, + mockExistingBuckets: map[string]mockBucket{}, + mockClientRegion: "my-region", + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + // IdP should be created and the existing one must not be deleted. + require.Contains(t, mipc.existingIDPUrl, expectedIssuerURL) + require.Contains(t, mipc.existingIDPUrl, "https://proxy.example.com") + + // The role must include the new statement and must not delete the previous one + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON("proxy.example.com"), + assumeRoleStatementJSON(expectedIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + + // Check Bucket creation + require.Contains(t, mipc.existingBuckets, "bucket-1") + bucket := mipc.existingBuckets["bucket-1"] + require.Equal(t, "my-region", bucket.region) + require.False(t, bucket.publicAccessIsBlocked) + expectedBucketPolicyDoc := policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-1", "prefix-2"), + ) + require.JSONEq(t, *expectedBucketPolicyDoc, *bucket.policyDoc) + }, + }, + { + name: "bucket already exists but is on another region", + mockAccountID: "123456789012", + req: baseIdPIAMConfigReqWithS3Bucket, + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{}, + mockExistingBuckets: map[string]mockBucket{ + "bucket-1": { + region: "another-region", + publicAccessIsBlocked: true, + }, + }, + mockClientRegion: "my-region", + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + // Check IdP creation + require.Contains(t, mipc.existingIDPUrl, expectedIssuerURL) + + // Check Role creation + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON(expectedIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + + // Check Bucket creation + require.Contains(t, mipc.existingBuckets, "bucket-1") + bucket := mipc.existingBuckets["bucket-1"] + require.False(t, bucket.publicAccessIsBlocked) + expectedBucketPolicyDoc := policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-1", "prefix-2"), + ) + require.JSONEq(t, *expectedBucketPolicyDoc, *bucket.policyDoc) + + // The last configured region must be the existing bucket's region. + require.Equal(t, "another-region", mipc.clientRegion) + }, + }, + { + name: "bucket already exists and already has a policy", + mockAccountID: "123456789012", + req: baseIdPIAMConfigReqWithS3Bucket, + mockExistingIdPUrl: []string{}, + mockExistingRoles: map[string]mockRole{}, + mockExistingBuckets: map[string]mockBucket{ + "bucket-1": { + region: "my-region", + publicAccessIsBlocked: true, + policyDoc: policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-2", "prefix-2"), + ), + }, + }, + mockClientRegion: "my-region", + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + // Check IdP creation + require.Contains(t, mipc.existingIDPUrl, expectedIssuerURL) + + // Check Role creation + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON(expectedIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + + // Check Bucket creation + require.Contains(t, mipc.existingBuckets, "bucket-1") + bucket := mipc.existingBuckets["bucket-1"] + require.False(t, bucket.publicAccessIsBlocked) + expectedBucketPolicyDoc := policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-2", "prefix-2"), + policyStatementS3PublicAccessJSON("bucket-1", "prefix-2"), + ) + require.JSONEq(t, *expectedBucketPolicyDoc, *bucket.policyDoc) + }, + }, + { + name: "everything already exists", + mockAccountID: "123456789012", + req: baseIdPIAMConfigReqWithS3Bucket, + mockExistingIdPUrl: []string{"https://bucket-1.s3.amazonaws.com/prefix-2"}, + mockExistingRoles: map[string]mockRole{ + "integrationrole": { + tags: []iamTypes.Tag{ + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myintegration")}, + }, + assumeRolePolicyDoc: policyDocWithStatementsJSON( + assumeRoleStatementJSON("bucket-1.s3.amazonaws.com/prefix-2"), + ), + }, + }, + mockExistingBuckets: map[string]mockBucket{ + "bucket-1": { + region: "my-region", + publicAccessIsBlocked: true, + policyDoc: policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-1", "prefix-2"), + ), + }, + }, + mockClientRegion: "my-region", + errCheck: require.NoError, + externalStateCheck: func(t *testing.T, mipc mockIdPIAMConfigClient) { + // Check IdP exists + require.Contains(t, mipc.existingIDPUrl, expectedIssuerURL) + + // Check Role exists + role := mipc.existingRoles["integrationrole"] + expectedAssumeRolePolicyDoc := policyDocWithStatementsJSON( + assumeRoleStatementJSON(expectedIssuer), + ) + require.JSONEq(t, *expectedAssumeRolePolicyDoc, aws.ToString(role.assumeRolePolicyDoc)) + + // Check Bucket exists + require.Contains(t, mipc.existingBuckets, "bucket-1") + bucket := mipc.existingBuckets["bucket-1"] + require.False(t, bucket.publicAccessIsBlocked) + expectedBucketPolicyDoc := policyDocWithStatementsJSON( + policyStatementS3PublicAccessJSON("bucket-1", "prefix-2"), + ) + require.JSONEq(t, *expectedBucketPolicyDoc, *bucket.policyDoc) + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + clt := mockIdPIAMConfigClient{ + accountID: tt.mockAccountID, + existingRoles: tt.mockExistingRoles, + existingIDPUrl: tt.mockExistingIdPUrl, + existingBuckets: tt.mockExistingBuckets, + clientRegion: tt.mockClientRegion, + } + + err := ConfigureIdPIAM(ctx, &clt, tt.req()) + tt.errCheck(t, err) + + if tt.externalStateCheck != nil { + tt.externalStateCheck(t, clt) + } + }) + } + }) +} + +type mockBucket struct { + region string + publicAccessIsBlocked bool + policyDoc *string } +type mockRole struct { + assumeRolePolicyDoc *string + tags []iamTypes.Tag +} type mockIdPIAMConfigClient struct { - accountID string - existingRoles []string - existingIDPUrl []string + clientRegion string + accountID string + existingIDPUrl []string + existingRoles map[string]mockRole + existingBuckets map[string]mockBucket } // GetCallerIdentity returns information about the caller identity. @@ -184,12 +660,16 @@ func (m *mockIdPIAMConfigClient) GetCallerIdentity(ctx context.Context, params * // CreateRole creates a new IAM Role. func (m *mockIdPIAMConfigClient) CreateRole(ctx context.Context, params *iam.CreateRoleInput, optFns ...func(*iam.Options)) (*iam.CreateRoleOutput, error) { alreadyExistsMessage := fmt.Sprintf("Role %q already exists.", *params.RoleName) - if slices.Contains(m.existingRoles, *params.RoleName) { + _, found := m.existingRoles[aws.ToString(params.RoleName)] + if found { return nil, &iamTypes.EntityAlreadyExistsException{ Message: &alreadyExistsMessage, } } - m.existingRoles = append(m.existingRoles, *params.RoleName) + m.existingRoles[*params.RoleName] = mockRole{ + tags: params.Tags, + assumeRolePolicyDoc: params.AssumeRolePolicyDocument, + } return &iam.CreateRoleOutput{ Role: &iamTypes.Role{ @@ -206,10 +686,146 @@ func (m *mockIdPIAMConfigClient) CreateOpenIDConnectProvider(ctx context.Context Message: &alreadyExistsMessage, } } - m.existingIDPUrl = append(m.existingRoles, *params.Url) + m.existingIDPUrl = append(m.existingIDPUrl, *params.Url) + + return &iam.CreateOpenIDConnectProviderOutput{}, nil +} + +// GetRole retrieves information about the specified role, including the role's path, +// GUID, ARN, and the role's trust policy that grants permission to assume the +// role. +func (m *mockIdPIAMConfigClient) GetRole(ctx context.Context, params *iam.GetRoleInput, optFns ...func(*iam.Options)) (*iam.GetRoleOutput, error) { + role, found := m.existingRoles[aws.ToString(params.RoleName)] + if !found { + return nil, trace.NotFound("role not found") + } + return &iam.GetRoleOutput{ + Role: &iamTypes.Role{ + Tags: role.tags, + AssumeRolePolicyDocument: role.assumeRolePolicyDoc, + }, + }, nil +} - return &iam.CreateOpenIDConnectProviderOutput{ - OpenIDConnectProviderArn: aws.String("arn:something"), +// UpdateAssumeRolePolicy updates the policy that grants an IAM entity permission to assume a role. +// This is typically referred to as the "role trust policy". +func (m *mockIdPIAMConfigClient) UpdateAssumeRolePolicy(ctx context.Context, params *iam.UpdateAssumeRolePolicyInput, optFns ...func(*iam.Options)) (*iam.UpdateAssumeRolePolicyOutput, error) { + role, found := m.existingRoles[aws.ToString(params.RoleName)] + if !found { + return nil, trace.NotFound("role not found") + } + + role.assumeRolePolicyDoc = params.PolicyDocument + m.existingRoles[aws.ToString(params.RoleName)] = role + + return &iam.UpdateAssumeRolePolicyOutput{}, nil +} + +// CreateBucket creates an Amazon S3 bucket. +func (m *mockIdPIAMConfigClient) CreateBucket(ctx context.Context, params *s3.CreateBucketInput, optFns ...func(*s3.Options)) (*s3.CreateBucketOutput, error) { + m.existingBuckets[*params.Bucket] = mockBucket{ + publicAccessIsBlocked: true, + region: m.clientRegion, + } + return nil, nil +} + +// PutObject adds an object to a bucket. +func (m *mockIdPIAMConfigClient) PutObject(ctx context.Context, params *s3.PutObjectInput, optFns ...func(*s3.Options)) (*s3.PutObjectOutput, error) { + return nil, nil +} + +// HeadBucket adds an object to a bucket. +func (m *mockIdPIAMConfigClient) HeadBucket(ctx context.Context, params *s3.HeadBucketInput, optFns ...func(*s3.Options)) (*s3.HeadBucketOutput, error) { + bucket, found := m.existingBuckets[*params.Bucket] + if !found { + return nil, trace.NotFound("bucket does not exist") + } + + return &s3.HeadBucketOutput{ + BucketRegion: &bucket.region, + }, nil +} + +// RegionForCreateBucket returns the default aws region to use when creating a bucket. +func (m *mockIdPIAMConfigClient) RegionForCreateBucket() string { + return m.clientRegion +} + +// SetAWSRegion sets the default aws region to use. +func (m *mockIdPIAMConfigClient) SetAWSRegion(awsRegion string) { + m.clientRegion = awsRegion +} + +// PutBucketPolicy applies an Amazon S3 bucket policy to an Amazon S3 bucket. +func (m *mockIdPIAMConfigClient) PutBucketPolicy(ctx context.Context, params *s3.PutBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.PutBucketPolicyOutput, error) { + bucket, found := m.existingBuckets[*params.Bucket] + if !found { + return nil, trace.NotFound("bucket does not exist") + } + + bucket.policyDoc = params.Policy + m.existingBuckets[*params.Bucket] = bucket + + return &s3.PutBucketPolicyOutput{}, nil +} + +// DeletePublicAccessBlock removes the PublicAccessBlock configuration for an Amazon S3 bucket. +func (m *mockIdPIAMConfigClient) DeletePublicAccessBlock(ctx context.Context, params *s3.DeletePublicAccessBlockInput, optFns ...func(*s3.Options)) (*s3.DeletePublicAccessBlockOutput, error) { + bucket, found := m.existingBuckets[*params.Bucket] + if !found { + return nil, trace.NotFound("bucket does not exist") + } + + bucket.publicAccessIsBlocked = false + m.existingBuckets[*params.Bucket] = bucket + + return &s3.DeletePublicAccessBlockOutput{}, nil +} + +// GetBucketPolicy returns the policy of a specified bucket +func (m *mockIdPIAMConfigClient) GetBucketPolicy(ctx context.Context, params *s3.GetBucketPolicyInput, optFns ...func(*s3.Options)) (*s3.GetBucketPolicyOutput, error) { + bucket, found := m.existingBuckets[*params.Bucket] + if !found { + return nil, trace.NotFound("bucket does not exist") + } + + if bucket.policyDoc == nil { + return nil, trace.NotFound("policy not set yet") + } + + return &s3.GetBucketPolicyOutput{ + Policy: bucket.policyDoc, + }, nil +} + +// HTTPHead does an HEAD HTTP Request to the target URL. +func (m *mockIdPIAMConfigClient) HTTPHead(ctx context.Context, endpoint string) (*http.Response, error) { + endpointURL, err := url.Parse(endpoint) + if err != nil { + return nil, trace.Wrap(err) + } + + // check if bucket exists + // expected URL is: https://s3.amazonaws.com// + endpointURLPath := strings.TrimLeft(endpointURL.Path, "/") + bucketName := strings.Split(endpointURLPath, "/")[0] + + bucket, found := m.existingBuckets[bucketName] + if !found { + return &http.Response{ + StatusCode: http.StatusNotFound, + Body: http.NoBody, + }, nil + } + + m.clientRegion = bucket.region + + return &http.Response{ + Header: http.Header{ + "x-amz-bucket-region": []string{bucket.region}, + }, + Body: http.NoBody, }, nil } diff --git a/lib/integrations/awsoidc/tags.go b/lib/integrations/awsoidc/tags.go index 37e7747de937b..7110df2b18dde 100644 --- a/lib/integrations/awsoidc/tags.go +++ b/lib/integrations/awsoidc/tags.go @@ -97,6 +97,23 @@ func (d AWSTags) MatchesECSTags(resourceTags []ecsTypes.Tag) bool { return true } +// MatchesIAMTags checks if the AWSTags are present and have the same value in resourceTags. +func (d AWSTags) MatchesIAMTags(resourceTags []iamTypes.Tag) bool { + resourceTagsMap := make(map[string]string, len(resourceTags)) + for _, tag := range resourceTags { + resourceTagsMap[*tag.Key] = *tag.Value + } + + for awsTagKey, awsTagValue := range d { + resourceTagValue, found := resourceTagsMap[awsTagKey] + if !found || resourceTagValue != awsTagValue { + return false + } + } + + return true +} + // ToIAMTags returns the default tags using the expected type for IAM resources: [iamTypes.Tag] func (d AWSTags) ToIAMTags() []iamTypes.Tag { iamTags := make([]iamTypes.Tag, 0, len(d)) diff --git a/lib/integrations/awsoidc/tags_test.go b/lib/integrations/awsoidc/tags_test.go index 7ff2b3f23ebc4..899235600ffe1 100644 --- a/lib/integrations/awsoidc/tags_test.go +++ b/lib/integrations/awsoidc/tags_test.go @@ -22,9 +22,9 @@ import ( "testing" "github.com/aws/aws-sdk-go-v2/aws" - ec2Types "github.com/aws/aws-sdk-go-v2/service/ec2/types" - ecsTypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" - iamTypes "github.com/aws/aws-sdk-go-v2/service/iam/types" + ec2types "github.com/aws/aws-sdk-go-v2/service/ec2/types" + ecstypes "github.com/aws/aws-sdk-go-v2/service/ecs/types" + iamtypes "github.com/aws/aws-sdk-go-v2/service/iam/types" "github.com/stretchr/testify/require" ) @@ -41,7 +41,7 @@ func TestDefaultTags(t *testing.T) { require.Equal(t, expectedTags, d) t.Run("iam tags", func(t *testing.T) { - expectedIAMTags := []iamTypes.Tag{ + expectedIAMTags := []iamtypes.Tag{ {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, @@ -50,7 +50,7 @@ func TestDefaultTags(t *testing.T) { }) t.Run("ecs tags", func(t *testing.T) { - expectedECSTags := []ecsTypes.Tag{ + expectedECSTags := []ecstypes.Tag{ {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, @@ -59,7 +59,7 @@ func TestDefaultTags(t *testing.T) { }) t.Run("ec2 tags", func(t *testing.T) { - expectedEC2Tags := []ec2Types.Tag{ + expectedEC2Tags := []ec2types.Tag{ {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, @@ -68,37 +68,73 @@ func TestDefaultTags(t *testing.T) { }) t.Run("resource is teleport managed", func(t *testing.T) { - t.Run("all tags match", func(t *testing.T) { - awsResourceTags := []ecsTypes.Tag{ - {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, - {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, - {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, - } - require.True(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as not Teleport managed") + t.Run("ECS Tags", func(t *testing.T) { + t.Run("all tags match", func(t *testing.T) { + awsResourceTags := []ecstypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + } + require.True(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as not Teleport managed") + }) + t.Run("extra tags in aws resource", func(t *testing.T) { + awsResourceTags := []ecstypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("unrelated"), Value: aws.String("true")}, + } + require.True(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as not Teleport managed") + }) + t.Run("missing one of the labels should return false", func(t *testing.T) { + awsResourceTags := []ecstypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + } + require.False(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as Teleport managed") + }) + t.Run("one of the labels has a different value, should return false", func(t *testing.T) { + awsResourceTags := []ecstypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("another-cluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + } + require.False(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as Teleport managed") + }) }) - t.Run("extra tags in aws resource", func(t *testing.T) { - awsResourceTags := []ecsTypes.Tag{ - {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, - {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, - {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, - {Key: aws.String("unrelated"), Value: aws.String("true")}, - } - require.True(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as not Teleport managed") - }) - t.Run("missing one of the labels should return false", func(t *testing.T) { - awsResourceTags := []ecsTypes.Tag{ - {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, - {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, - } - require.False(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as Teleport managed") - }) - t.Run("one of the labels has a different value, should return false", func(t *testing.T) { - awsResourceTags := []ecsTypes.Tag{ - {Key: aws.String("teleport.dev/cluster"), Value: aws.String("another-cluster")}, - {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, - {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, - } - require.False(t, d.MatchesECSTags(awsResourceTags), "resource was wrongly detected as Teleport managed") + t.Run("IAM Tags", func(t *testing.T) { + t.Run("all tags match", func(t *testing.T) { + awsResourceTags := []iamtypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + } + require.True(t, d.MatchesIAMTags(awsResourceTags), "resource was wrongly detected as not Teleport managed") + }) + t.Run("extra tags in aws resource", func(t *testing.T) { + awsResourceTags := []iamtypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + {Key: aws.String("unrelated"), Value: aws.String("true")}, + } + require.True(t, d.MatchesIAMTags(awsResourceTags), "resource was wrongly detected as not Teleport managed") + }) + t.Run("missing one of the labels should return false", func(t *testing.T) { + awsResourceTags := []iamtypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("mycluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + } + require.False(t, d.MatchesIAMTags(awsResourceTags), "resource was wrongly detected as Teleport managed") + }) + t.Run("one of the labels has a different value, should return false", func(t *testing.T) { + awsResourceTags := []iamtypes.Tag{ + {Key: aws.String("teleport.dev/cluster"), Value: aws.String("another-cluster")}, + {Key: aws.String("teleport.dev/integration"), Value: aws.String("myawsaccount")}, + {Key: aws.String("teleport.dev/origin"), Value: aws.String("integration_awsoidc")}, + } + require.False(t, d.MatchesIAMTags(awsResourceTags), "resource was wrongly detected as Teleport managed") + }) }) }) } diff --git a/lib/integrations/externalauditstorage/bootstrap.go b/lib/integrations/externalauditstorage/bootstrap.go index b50ded240ca5f..62e1cc55bbebd 100644 --- a/lib/integrations/externalauditstorage/bootstrap.go +++ b/lib/integrations/externalauditstorage/bootstrap.go @@ -205,7 +205,7 @@ func createTransientBucket(ctx context.Context, clt BootstrapS3Client, bucketNam func createBucket(ctx context.Context, clt BootstrapS3Client, bucketName string, region string, objectLock bool) error { _, err := clt.CreateBucket(ctx, &s3.CreateBucketInput{ Bucket: &bucketName, - CreateBucketConfiguration: createBucketConfiguration(region), + CreateBucketConfiguration: awsutil.CreateBucketConfiguration(region), ObjectLockEnabledForBucket: aws.Bool(objectLock), ACL: s3types.BucketCannedACLPrivate, ObjectOwnership: s3types.ObjectOwnershipBucketOwnerEnforced, @@ -223,18 +223,6 @@ func createBucket(ctx context.Context, clt BootstrapS3Client, bucketName string, return trace.Wrap(awsutil.ConvertS3Error(err), "setting versioning configuration on S3 bucket") } -func createBucketConfiguration(region string) *s3types.CreateBucketConfiguration { - // No location constraint wanted for us-east-1 because it is the default and - // AWS has decided, in all their infinite wisdom, that the CreateBucket API - // should fail if you explicitly pass the default location constraint. - if region == "us-east-1" { - return nil - } - return &s3types.CreateBucketConfiguration{ - LocationConstraint: s3types.BucketLocationConstraint(region), - } -} - // createAthenaWorkgroup creates an athena workgroup in which to run athena sql queries. func createAthenaWorkgroup(ctx context.Context, clt BootstrapAthenaClient, workgroup string) error { fmt.Printf("Creating Athena workgroup %s\n", workgroup) diff --git a/lib/utils/aws/s3.go b/lib/utils/aws/s3.go index b2851d89ea687..5ac405a816de1 100644 --- a/lib/utils/aws/s3.go +++ b/lib/utils/aws/s3.go @@ -27,7 +27,7 @@ import ( awsV2 "github.com/aws/aws-sdk-go-v2/aws" managerV2 "github.com/aws/aws-sdk-go-v2/feature/s3/manager" s3v2 "github.com/aws/aws-sdk-go-v2/service/s3" - s3Types "github.com/aws/aws-sdk-go-v2/service/s3/types" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/aws/aws-sdk-go/aws/awserr" "github.com/aws/aws-sdk-go/service/s3" "github.com/gravitational/trace" @@ -59,27 +59,27 @@ func ConvertS3Error(err error, args ...interface{}) error { } // SDK v2 errors: - var noSuchKey *s3Types.NoSuchKey + var noSuchKey *s3types.NoSuchKey if errors.As(err, &noSuchKey) { return trace.NotFound(noSuchKey.Error(), args...) } - var noSuchBucket *s3Types.NoSuchBucket + var noSuchBucket *s3types.NoSuchBucket if errors.As(err, &noSuchBucket) { return trace.NotFound(noSuchBucket.Error(), args...) } - var noSuchUpload *s3Types.NoSuchUpload + var noSuchUpload *s3types.NoSuchUpload if errors.As(err, &noSuchUpload) { return trace.NotFound(noSuchUpload.Error(), args...) } - var bucketAlreadyExists *s3Types.BucketAlreadyExists + var bucketAlreadyExists *s3types.BucketAlreadyExists if errors.As(err, &bucketAlreadyExists) { return trace.AlreadyExists(bucketAlreadyExists.Error(), args...) } - var bucketAlreadyOwned *s3Types.BucketAlreadyOwnedByYou + var bucketAlreadyOwned *s3types.BucketAlreadyOwnedByYou if errors.As(err, &bucketAlreadyOwned) { return trace.AlreadyExists(bucketAlreadyOwned.Error(), args...) } - var notFound *s3Types.NotFound + var notFound *s3types.NotFound if errors.As(err, ¬Found) { return trace.NotFound(notFound.Error(), args...) } @@ -147,3 +147,16 @@ func (s *s3V2FileWriter) Close() error { rCloseErr := s.pipeReader.Close() return trace.Wrap(trace.NewAggregate(wCloseErr, readerErr, rCloseErr)) } + +// CreateBucketConfiguration creates the default CreateBucketConfiguration. +func CreateBucketConfiguration(region string) *s3types.CreateBucketConfiguration { + // No location constraint wanted for us-east-1 because it is the default and + // AWS has decided, in all their infinite wisdom, that the CreateBucket API + // should fail if you explicitly pass the default location constraint. + if region == "us-east-1" { + return nil + } + return &s3types.CreateBucketConfiguration{ + LocationConstraint: s3types.BucketLocationConstraint(region), + } +} diff --git a/lib/utils/aws/s3_test.go b/lib/utils/aws/s3_test.go index c94f8025fa5a6..420525cecdc9e 100644 --- a/lib/utils/aws/s3_test.go +++ b/lib/utils/aws/s3_test.go @@ -29,6 +29,7 @@ import ( "github.com/aws/aws-sdk-go-v2/feature/s3/manager" "github.com/aws/aws-sdk-go-v2/service/s3" + s3types "github.com/aws/aws-sdk-go-v2/service/s3/types" "github.com/stretchr/testify/require" ) @@ -112,3 +113,36 @@ func (s *s3ClientMock) PutObject(ctx context.Context, in *s3.PutObjectInput, opt return &s3.PutObjectOutput{}, nil } } + +func TestCreateBucketConfiguration(t *testing.T) { + for _, tt := range []struct { + name string + regionIn string + expected *s3types.CreateBucketConfiguration + }{ + { + name: "special region", + regionIn: "us-east-1", + expected: nil, + }, + { + name: "regular region", + regionIn: "us-east-2", + expected: &s3types.CreateBucketConfiguration{ + LocationConstraint: s3types.BucketLocationConstraintUsEast2, + }, + }, + { + name: "unknown region", + regionIn: "unknown", + expected: &s3types.CreateBucketConfiguration{ + LocationConstraint: "unknown", + }, + }, + } { + t.Run(tt.name, func(t *testing.T) { + got := CreateBucketConfiguration(tt.regionIn) + require.Equal(t, tt.expected, got) + }) + } +} diff --git a/lib/utils/oidc/openidconfig.go b/lib/utils/oidc/openidconfig.go new file mode 100644 index 0000000000000..e4ec68ac15828 --- /dev/null +++ b/lib/utils/oidc/openidconfig.go @@ -0,0 +1,44 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package oidc + +// OpenIDConfiguration is the default OpenID Configuration used by Teleport. +type OpenIDConfiguration struct { + Issuer string `json:"issuer"` + JWKSURI string `json:"jwks_uri"` + Claims []string `json:"claims"` + IdTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` + ResponseTypesSupported []string `json:"response_types_supported"` + ScopesSupported []string `json:"scopes_supported"` + SubjectTypesSupported []string `json:"subject_types_supported"` +} + +// OpenIDConfigurationForIssuer returns the OpenID Configuration for +// the given issuer and JWKS URI. +func OpenIDConfigurationForIssuer(issuer, jwksURI string) OpenIDConfiguration { + return OpenIDConfiguration{ + Issuer: issuer, + JWKSURI: jwksURI, + Claims: []string{"iss", "sub", "obo", "aud", "jti", "iat", "exp", "nbf"}, + IdTokenSigningAlgValuesSupported: []string{"RS256"}, + ResponseTypesSupported: []string{"id_token"}, + ScopesSupported: []string{"openid"}, + SubjectTypesSupported: []string{"public", "pair-wise"}, + } +} diff --git a/lib/utils/oidc/openidconfig_test.go b/lib/utils/oidc/openidconfig_test.go new file mode 100644 index 0000000000000..8796a978d0e1e --- /dev/null +++ b/lib/utils/oidc/openidconfig_test.go @@ -0,0 +1,41 @@ +/* + * Teleport + * Copyright (C) 2024 Gravitational, Inc. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU Affero General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU Affero General Public License for more details. + * + * You should have received a copy of the GNU Affero General Public License + * along with this program. If not, see . + */ + +package oidc + +import ( + "testing" + + "github.com/stretchr/testify/require" +) + +func TestOpenIDConfigurationForIssuer(t *testing.T) { + expected := OpenIDConfiguration{ + Issuer: "https://localhost:8080", + JWKSURI: "https://localhost:8080/.well-known/jwks", + Claims: []string{"iss", "sub", "obo", "aud", "jti", "iat", "exp", "nbf"}, + IdTokenSigningAlgValuesSupported: []string{"RS256"}, + ResponseTypesSupported: []string{"id_token"}, + ScopesSupported: []string{"openid"}, + SubjectTypesSupported: []string{"public", "pair-wise"}, + } + + got := OpenIDConfigurationForIssuer("https://localhost:8080", "https://localhost:8080/.well-known/jwks") + require.Equal(t, expected, got) + +} diff --git a/lib/web/oidcidp.go b/lib/web/oidcidp.go index d7c794dfd6b65..57badf2c09fe4 100644 --- a/lib/web/oidcidp.go +++ b/lib/web/oidcidp.go @@ -42,23 +42,7 @@ func (h *Handler) openidConfiguration(_ http.ResponseWriter, _ *http.Request, _ return nil, trace.Wrap(err) } - return struct { - Issuer string `json:"issuer"` - JWKSURI string `json:"jwks_uri"` - Claims []string `json:"claims"` - IdTokenSigningAlgValuesSupported []string `json:"id_token_signing_alg_values_supported"` - ResponseTypesSupported []string `json:"response_types_supported"` - ScopesSupported []string `json:"scopes_supported"` - SubjectTypesSupported []string `json:"subject_types_supported"` - }{ - Issuer: issuer, - JWKSURI: issuer + OIDCJWKWURI, - Claims: []string{"iss", "sub", "obo", "aud", "jti", "iat", "exp", "nbf"}, - IdTokenSigningAlgValuesSupported: []string{"RS256"}, - ResponseTypesSupported: []string{"id_token"}, - ScopesSupported: []string{"openid"}, - SubjectTypesSupported: []string{"public", "pair-wise"}, - }, nil + return oidc.OpenIDConfigurationForIssuer(issuer, issuer+OIDCJWKWURI), nil } // jwksOIDC returns all public keys used to sign JWT tokens for this cluster. diff --git a/tool/teleport/common/teleport.go b/tool/teleport/common/teleport.go index 7d6afb9549b6c..ab5499d1861f3 100644 --- a/tool/teleport/common/teleport.go +++ b/tool/teleport/common/teleport.go @@ -486,9 +486,13 @@ func Run(options Options) (app *kingpin.Application, executedCommand string, con IntegrationConfAWSOIDCIdPArguments.Name) integrationConfAWSOIDCIdPCmd.Flag("role", "The AWS Role used by the AWS OIDC Integration.").Required().StringVar(&ccf. IntegrationConfAWSOIDCIdPArguments.Role) - integrationConfAWSOIDCIdPCmd.Flag("proxy-public-url", "Proxy Public URL (eg https://mytenant.teleport.sh).").Required().StringVar(&ccf. + integrationConfAWSOIDCIdPCmd.Flag("proxy-public-url", "Proxy Public URL (eg https://mytenant.teleport.sh).").StringVar(&ccf. IntegrationConfAWSOIDCIdPArguments.ProxyPublicURL) integrationConfAWSOIDCIdPCmd.Flag("insecure", "Insecure mode disables certificate validation.").BoolVar(&ccf.InsecureMode) + integrationConfAWSOIDCIdPCmd.Flag("s3-bucket-uri", "The S3 URI(format: s3:///) used to store the OpenID configuration and public keys. ").StringVar(&ccf. + IntegrationConfAWSOIDCIdPArguments.S3BucketURI) + integrationConfAWSOIDCIdPCmd.Flag("s3-jwks-base64", `The JWKS base 64 encoded. Required when using the S3 Bucket as the Issuer URL. Format: base64({"keys":[{"kty":"RSA","alg":"RS256","n":"","e":"","use":"sig","kid":""}]}).`).StringVar(&ccf. + IntegrationConfAWSOIDCIdPArguments.S3JWKSContentsB64) integrationConfListDatabasesCmd := integrationConfigureCmd.Command("listdatabases-iam", "Adds required IAM permissions to List RDS Databases (Instances and Clusters).") integrationConfListDatabasesCmd.Flag("aws-region", "AWS Region.").Required().StringVar(&ccf.IntegrationConfListDatabasesIAMArguments.Region) @@ -1031,6 +1035,8 @@ func onIntegrationConfAWSOIDCIdP(clf config.CommandLineFlags) error { IntegrationName: clf.IntegrationConfAWSOIDCIdPArguments.Name, IntegrationRole: clf.IntegrationConfAWSOIDCIdPArguments.Role, ProxyPublicAddress: clf.IntegrationConfAWSOIDCIdPArguments.ProxyPublicURL, + S3BucketLocation: clf.IntegrationConfAWSOIDCIdPArguments.S3BucketURI, + S3JWKSContentsB64: clf.IntegrationConfAWSOIDCIdPArguments.S3JWKSContentsB64, }) if err != nil { return trace.Wrap(err)