diff --git a/lib/tbot/service_workload_identity_aws_ra.go b/lib/tbot/service_workload_identity_aws_ra.go index 7c29b852b31b5..5cf33b0331533 100644 --- a/lib/tbot/service_workload_identity_aws_ra.go +++ b/lib/tbot/service_workload_identity_aws_ra.go @@ -91,7 +91,14 @@ func (s *WorkloadIdentityAWSRAService) generate(ctx context.Context) error { if err != nil { return trace.Wrap(err, "marshaling private key") } - svid, err := x509svid.ParseRaw(res.GetX509Svid().Cert, pkcs8) + certWithChain := new(bytes.Buffer) + _, _ = certWithChain.Write(res.GetX509Svid().GetCert()) + // If external PKI is configured, we need to append the chain to the leaf + // certificate before calling x509svid.ParseRaw. + for _, cert := range res.GetX509Svid().GetChain() { + _, _ = certWithChain.Write(cert) + } + svid, err := x509svid.ParseRaw(certWithChain.Bytes(), pkcs8) if err != nil { return trace.Wrap(err, "parsing x509 svid") } diff --git a/lib/tbot/service_workload_identity_aws_ra_test.go b/lib/tbot/service_workload_identity_aws_ra_test.go index 80eb7e09f905a..b718f0dbe1540 100644 --- a/lib/tbot/service_workload_identity_aws_ra_test.go +++ b/lib/tbot/service_workload_identity_aws_ra_test.go @@ -36,6 +36,7 @@ import ( "github.com/gravitational/teleport/api/types" apiutils "github.com/gravitational/teleport/api/utils" "github.com/gravitational/teleport/lib/tbot/config" + "github.com/gravitational/teleport/lib/tlsca" "github.com/gravitational/teleport/lib/utils" "github.com/gravitational/teleport/lib/utils/testutils/golden" "github.com/gravitational/teleport/tool/teleport/testenv" @@ -145,48 +146,97 @@ func TestBotWorkloadIdentityAWSRA(t *testing.T) { ctx := context.Background() log := utils.NewSlogLoggerForTests() - process := testenv.MakeTestServer(t, defaultTestServerOpts(t, log)) - rootClient := testenv.MakeDefaultAuthClient(t, process) + tests := []struct { + name string + externalPKI bool + }{ + { + name: "no external pki", + }, + { + name: "external pki", + externalPKI: true, + }, + } + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + process := testenv.MakeTestServer(t, defaultTestServerOpts(t, log)) + if tt.externalPKI { + setWorkloadIdentityX509CAOverride(ctx, t, process) + } + spiffeCA, err := process.GetAuthServer(). + GetCertAuthority(ctx, types.CertAuthID{ + DomainName: "root", + Type: types.SPIFFECA, + }, false) + require.NoError(t, err) + spiffeCAX509KeyPairs := spiffeCA.GetTrustedTLSKeyPairs() + require.Len(t, spiffeCAX509KeyPairs, 1) + spiffeCACert, err := tlsca.ParseCertificatePEM(spiffeCAX509KeyPairs[0].Cert) + require.NoError(t, err) + rootClient := testenv.MakeDefaultAuthClient(t, process) + + roleArn := "arn:aws:iam::123456789012:role/example-role" + trustAnchorArn := "arn:aws:rolesanywhere:us-east-1:123456789012:trust-anchor/0000000-0000-0000-0000-000000000000" + profileArn := "arn:aws:rolesanywhere:us-east-1:123456789012:profile/0000000-0000-0000-0000-00000000000" + srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + assert.Equal(t, "/sessions", r.URL.Path) + assert.Equal(t, http.MethodPost, r.Method) - roleArn := "arn:aws:iam::123456789012:role/example-role" - trustAnchorArn := "arn:aws:rolesanywhere:us-east-1:123456789012:trust-anchor/0000000-0000-0000-0000-000000000000" - profileArn := "arn:aws:rolesanywhere:us-east-1:123456789012:profile/0000000-0000-0000-0000-00000000000" - srv := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - assert.Equal(t, "/sessions", r.URL.Path) - assert.Equal(t, http.MethodPost, r.Method) + // Check query parameter inputs + // The AWS documentation "lies" about these inputs using the JSON body + // - the rolesanywhere API client in + // `aws/rolesanywhere-credential-helper` uses query parameters for + // these. + assert.Equal(t, roleArn, r.URL.Query().Get("roleArn")) + assert.Equal(t, trustAnchorArn, r.URL.Query().Get("trustAnchorArn")) + assert.Equal(t, profileArn, r.URL.Query().Get("profileArn")) - // Check query parameter inputs - // The AWS documentation "lies" about these inputs using the JSON body - // - the rolesanywhere API client in - // `aws/rolesanywhere-credential-helper` uses query parameters for - // these. - assert.Equal(t, roleArn, r.URL.Query().Get("roleArn")) - assert.Equal(t, trustAnchorArn, r.URL.Query().Get("trustAnchorArn")) - assert.Equal(t, profileArn, r.URL.Query().Get("profileArn")) + // Check JSON body inputs + body := &mockCreateSessionInputBody{} + assert.NoError(t, json.NewDecoder(r.Body).Decode(body)) + assert.Equal(t, int((2 * time.Hour).Seconds()), body.DurationSeconds) - // Check JSON body inputs - body := &mockCreateSessionInputBody{} - assert.NoError(t, json.NewDecoder(r.Body).Decode(body)) - assert.Equal(t, int((2 * time.Hour).Seconds()), body.DurationSeconds) + // Validate the X-Amz-X509 header contains the valid (and correct) SVID + derString := r.Header.Get("X-Amz-X509") + assert.NotEmpty(t, derString) + derBytes, err := base64.StdEncoding.DecodeString(derString) + assert.NoError(t, err) + cert, err := x509.ParseCertificate(derBytes) + assert.NoError(t, err) + assert.Len(t, cert.URIs, 1) + assert.Equal(t, "spiffe://root/ra-test", cert.URIs[0].String()) - // Validate the X-Amz-X509 header contains the valid (and correct) SVID - derString := r.Header.Get("X-Amz-X509") - assert.NotEmpty(t, derString) - derBytes, err := base64.StdEncoding.DecodeString(derString) - assert.NoError(t, err) - cert, err := x509.ParseCertificate(derBytes) - assert.NoError(t, err) - assert.Len(t, cert.URIs, 1) - assert.Equal(t, "spiffe://root/ra-test", cert.URIs[0].String()) + // Validate the X-Amz-X509-Chain header contains the valid chain + chainString := r.Header.Get("X-Amz-X509-Chain") + if tt.externalPKI { + require.NotEmpty(t, chainString) + // If there were multiple certs in the chain, we'd need to + // split by comma first since: + // + // > The X-Amz-X509-Chain header MUST be encoded as + // > comma-delimited, base64-encoded DER + // + // But since we only expect a single item in the chain here + // we can just decode it. + chainBytes, err := base64.StdEncoding.DecodeString(chainString) + assert.NoError(t, err) + chainCert, err := x509.ParseCertificate(chainBytes) + assert.NoError(t, err) + // Check this matches the actual CA we setup. + assert.True(t, chainCert.Equal(spiffeCACert)) + } else { + require.Empty(t, chainString) + } - // Validate the authorization header exists. We rely on the AWS SDK to - // actually produce the signature, and, validating this signature would - // introduce significant complexity to this test - so this is omitted. - authz := r.Header.Get("Authorization") - assert.NotEmpty(t, authz) + // Validate the authorization header exists. We rely on the AWS SDK to + // actually produce the signature, and, validating this signature would + // introduce significant complexity to this test - so this is omitted. + authz := r.Header.Get("Authorization") + assert.NotEmpty(t, authz) - // Send mocked response - _, _ = w.Write([]byte(`{ + // Send mocked response + _, _ = w.Write([]byte(`{ "credentialSet":[ { "assumedRoleUser": { @@ -206,81 +256,83 @@ func TestBotWorkloadIdentityAWSRA(t *testing.T) { ], "subjectArn": "arn:aws:rolesanywhere:us-east-1:000000000000:subject/41cl0bae-6783-40d4-ab20-65dc5d922e45" }`)) - })) - t.Cleanup(srv.Close) + })) + t.Cleanup(srv.Close) - role, err := types.NewRole("issue-foo", types.RoleSpecV6{ - Allow: types.RoleConditions{ - WorkloadIdentityLabels: map[string]apiutils.Strings{ - "foo": []string{"bar"}, - }, - Rules: []types.Rule{ - { - Resources: []string{types.KindWorkloadIdentity}, - Verbs: []string{types.VerbRead, types.VerbList}, + role, err := types.NewRole("issue-foo", types.RoleSpecV6{ + Allow: types.RoleConditions{ + WorkloadIdentityLabels: map[string]apiutils.Strings{ + "foo": []string{"bar"}, + }, + Rules: []types.Rule{ + { + Resources: []string{types.KindWorkloadIdentity}, + Verbs: []string{types.VerbRead, types.VerbList}, + }, + }, }, - }, - }, - }) - require.NoError(t, err) - role, err = rootClient.UpsertRole(ctx, role) - require.NoError(t, err) + }) + require.NoError(t, err) + role, err = rootClient.UpsertRole(ctx, role) + require.NoError(t, err) - workloadIdentity := &workloadidentityv1pb.WorkloadIdentity{ - Kind: types.KindWorkloadIdentity, - Version: types.V1, - Metadata: &headerv1.Metadata{ - Name: "foo-bar-bizz", - Labels: map[string]string{ - "foo": "bar", - }, - }, - Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ - Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ - Id: "/ra-test", - }, - }, - } - workloadIdentity, err = rootClient.WorkloadIdentityResourceServiceClient(). - CreateWorkloadIdentity(ctx, &workloadidentityv1pb.CreateWorkloadIdentityRequest{ - WorkloadIdentity: workloadIdentity, - }) - require.NoError(t, err) + workloadIdentity := &workloadidentityv1pb.WorkloadIdentity{ + Kind: types.KindWorkloadIdentity, + Version: types.V1, + Metadata: &headerv1.Metadata{ + Name: "foo-bar-bizz", + Labels: map[string]string{ + "foo": "bar", + }, + }, + Spec: &workloadidentityv1pb.WorkloadIdentitySpec{ + Spiffe: &workloadidentityv1pb.WorkloadIdentitySPIFFE{ + Id: "/ra-test", + }, + }, + } + workloadIdentity, err = rootClient.WorkloadIdentityResourceServiceClient(). + CreateWorkloadIdentity(ctx, &workloadidentityv1pb.CreateWorkloadIdentityRequest{ + WorkloadIdentity: workloadIdentity, + }) + require.NoError(t, err) - tmpDir := t.TempDir() - onboarding, _ := makeBot(t, rootClient, "ra-test", role.GetName()) - botConfig := defaultBotConfig(t, process, onboarding, config.ServiceConfigs{ - &config.WorkloadIdentityAWSRAService{ - Selector: config.WorkloadIdentitySelector{ - Name: workloadIdentity.GetMetadata().GetName(), - }, - Destination: &config.DestinationDirectory{ - Path: tmpDir, - }, - RoleARN: roleArn, - ProfileARN: profileArn, - TrustAnchorARN: trustAnchorArn, - Region: "us-east-1", - SessionDuration: 2 * time.Hour, - SessionRenewalInterval: 30 * time.Minute, - EndpointOverride: srv.URL, - }, - }, defaultBotConfigOpts{ - useAuthServer: true, - insecure: true, - }) + tmpDir := t.TempDir() + onboarding, _ := makeBot(t, rootClient, "ra-test", role.GetName()) + botConfig := defaultBotConfig(t, process, onboarding, config.ServiceConfigs{ + &config.WorkloadIdentityAWSRAService{ + Selector: config.WorkloadIdentitySelector{ + Name: workloadIdentity.GetMetadata().GetName(), + }, + Destination: &config.DestinationDirectory{ + Path: tmpDir, + }, + RoleARN: roleArn, + ProfileARN: profileArn, + TrustAnchorARN: trustAnchorArn, + Region: "us-east-1", + SessionDuration: 2 * time.Hour, + SessionRenewalInterval: 30 * time.Minute, + EndpointOverride: srv.URL, + }, + }, defaultBotConfigOpts{ + useAuthServer: true, + insecure: true, + }) - botConfig.Oneshot = true - b := New(botConfig, log) - // Run Bot with 10 second timeout to catch hangs. - ctx, cancel := context.WithTimeout(ctx, time.Second*10) - t.Cleanup(cancel) - require.NoError(t, b.Run(ctx)) + botConfig.Oneshot = true + b := New(botConfig, log) + // Run Bot with 10 second timeout to catch hangs. + ctx, cancel := context.WithTimeout(ctx, time.Second*10) + t.Cleanup(cancel) + require.NoError(t, b.Run(ctx)) - got, err := os.ReadFile(filepath.Join(tmpDir, "aws_credentials")) - require.NoError(t, err) - if golden.ShouldSet() { - golden.Set(t, got) + got, err := os.ReadFile(filepath.Join(tmpDir, "aws_credentials")) + require.NoError(t, err) + if golden.ShouldSet() { + golden.Set(t, got) + } + require.Equal(t, string(golden.Get(t)), string(got)) + }) } - require.Equal(t, string(golden.Get(t)), string(got)) } diff --git a/lib/tbot/testdata/TestBotWorkloadIdentityAWSRA.golden b/lib/tbot/testdata/TestBotWorkloadIdentityAWSRA/external_pki.golden similarity index 100% rename from lib/tbot/testdata/TestBotWorkloadIdentityAWSRA.golden rename to lib/tbot/testdata/TestBotWorkloadIdentityAWSRA/external_pki.golden diff --git a/lib/tbot/testdata/TestBotWorkloadIdentityAWSRA/no_external_pki.golden b/lib/tbot/testdata/TestBotWorkloadIdentityAWSRA/no_external_pki.golden new file mode 100644 index 0000000000000..81841c661dfb5 --- /dev/null +++ b/lib/tbot/testdata/TestBotWorkloadIdentityAWSRA/no_external_pki.golden @@ -0,0 +1,5 @@ +[default] +aws_secret_access_key=secretAccessKey +aws_access_key_id=accessKeyId +aws_session_token=sessionToken +expiration=1848285415000