diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index cbe090ce7202d..c6045e0e8e2a7 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -72,12 +72,12 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co // CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - item, err := s.svc.MakeBackendItem(sp, sp.GetName()) - if err != nil { + if err := validateSAMLIdPServiceProvider(sp); err != nil { return trace.Wrap(err) } - if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil { + item, err := s.svc.MakeBackendItem(sp, sp.GetName()) + if err != nil { return trace.Wrap(err) } @@ -87,7 +87,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context return trace.Wrap(err) } - _, err = backend.Create(ctx, item) + _, err := backend.Create(ctx, item) if trace.IsAlreadyExists(err) { return trace.AlreadyExists("%s %q already exists", types.KindSAMLIdPServiceProvider, sp.GetName()) } @@ -97,12 +97,12 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context // UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - item, err := s.svc.MakeBackendItem(sp, sp.GetName()) - if err != nil { + if err := validateSAMLIdPServiceProvider(sp); err != nil { return trace.Wrap(err) } - if err := s.ensureEntityDescriptorMatchesEntityID(sp); err != nil { + item, err := s.svc.MakeBackendItem(sp, sp.GetName()) + if err != nil { return trace.Wrap(err) } @@ -112,7 +112,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context return trace.Wrap(err) } - _, err = backend.Update(ctx, item) + _, err := backend.Update(ctx, item) if trace.IsNotFound(err) { return trace.NotFound("%s %q doesn't exist", types.KindSAMLIdPServiceProvider, sp.GetName()) } @@ -159,9 +159,9 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte return nil } -// ensureEntityDescriptorMatchesEntityID ensures that the entity ID in the entity descriptor is the same as the entity ID -// in the SAMLIdPServiceProvider object. -func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp types.SAMLIdPServiceProvider) error { +// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID +// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints. +func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error { ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) if err != nil { return trace.Wrap(err) @@ -171,5 +171,13 @@ func (s *SAMLIdPServiceProviderService) ensureEntityDescriptorMatchesEntityID(sp return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") } + for _, descriptor := range ed.SPSSODescriptors { + for _, acs := range descriptor.AssertionConsumerServices { + if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil { + return trace.Wrap(err) + } + } + } + return nil } diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go index 86a66cbdda2c9..4a90318e8bd77 100644 --- a/lib/services/local/saml_idp_service_provider_test.go +++ b/lib/services/local/saml_idp_service_provider_test.go @@ -23,6 +23,7 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" + "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" "github.com/stretchr/testify/require" @@ -64,6 +65,20 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { }) require.NoError(t, err) + // Try to create an invalid service provider with an invalid acs. + sp3, err := types.NewSAMLIdPServiceProvider( + types.Metadata{ + Name: "sp3", + }, + types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: newInvalidACSEntityDescriptor("sp1"), + EntityID: "sp1", + }) + require.NoError(t, err) + err = service.CreateSAMLIdPServiceProvider(ctx, sp3) + require.Error(t, err) + require.True(t, trace.IsBadParameter(err)) + // Initially we expect no service providers. out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "") require.NoError(t, err) @@ -163,6 +178,14 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { err = service.UpdateSAMLIdPServiceProvider(ctx, sp) require.Error(t, err) + // Update a service provider with an invalid acs. + sp, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName()) + require.NoError(t, err) + sp.SetEntityDescriptor(newInvalidACSEntityDescriptor(sp1.GetEntityID())) + err = service.UpdateSAMLIdPServiceProvider(ctx, sp) + require.Error(t, err) + require.True(t, trace.IsBadParameter(err)) + // Delete a service provider. err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName()) require.NoError(t, err) @@ -186,6 +209,49 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { require.Empty(t, out) } +func TestValidateSAMLIdPServiceProvider(t *testing.T) { + descriptor := newEntityDescriptor("IAMShowcase") + + cases := []struct { + name string + spec types.SAMLIdPServiceProviderSpecV1 + errAssertion require.ErrorAssertionFunc + }{ + { + name: "valid provider", + spec: types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: descriptor, + EntityID: "IAMShowcase", + }, + errAssertion: require.NoError, + }, + { + name: "invalid entity id", + spec: types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: descriptor, + EntityID: uuid.NewString(), + }, + errAssertion: require.Error, + }, + { + name: "invalid acs", + spec: types.SAMLIdPServiceProviderSpecV1{ + EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"), + EntityID: "IAMShowcase", + }, + errAssertion: require.Error, + }, + } + + for _, test := range cases { + t.Run(test.name, func(t *testing.T) { + sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec) + require.NoError(t, err) + test.errAssertion(t, validateSAMLIdPServiceProvider(sp)) + }) + } +} + func newEntityDescriptor(entityID string) string { return fmt.Sprintf(testEntityDescriptor, entityID) } @@ -200,3 +266,18 @@ const testEntityDescriptor = ` ` + +func newInvalidACSEntityDescriptor(entityID string) string { + return fmt.Sprintf(invalidEntityDescriptor, entityID) +} + +// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations. +const invalidEntityDescriptor = ` + + + urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + + + +` diff --git a/lib/services/saml_idp_service_provider.go b/lib/services/saml_idp_service_provider.go index 31f76648688b0..dcd750a5008f6 100644 --- a/lib/services/saml_idp_service_provider.go +++ b/lib/services/saml_idp_service_provider.go @@ -18,6 +18,7 @@ package services import ( "context" + "net/url" "github.com/gravitational/trace" @@ -25,7 +26,7 @@ import ( "github.com/gravitational/teleport/lib/utils" ) -// SAMLIdPServiceProvider defines an interface for managing SAML IdP service providers. +// SAMLIdPServiceProviders defines an interface for managing SAML IdP service providers. type SAMLIdPServiceProviders interface { // ListSAMLIdPServiceProviders returns a paginated list of all SAML IdP service provider resources. ListSAMLIdPServiceProviders(context.Context, int, string) ([]types.SAMLIdPServiceProvider, string, error) @@ -115,3 +116,17 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string) } return &s, nil } + +// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location +// is a valid HTTPS endpoint. +func ValidateAssertionConsumerServicesEndpoint(acs string) error { + endpoint, err := url.Parse(acs) + switch { + case err != nil: + return trace.Wrap(err) + case endpoint.Scheme != "https": + return trace.BadParameter("the assertion consumer services location must be an https endpoint") + } + + return nil +} diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go index 25f83c60d95ce..51987205e1c3d 100644 --- a/lib/services/saml_idp_service_provider_test.go +++ b/lib/services/saml_idp_service_provider_test.go @@ -61,6 +61,32 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) { require.Equal(t, expected, actual) } +func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) { + cases := []struct { + location string + assertion require.ErrorAssertionFunc + }{ + { + location: "https://sptest.iamshowcase.com/acs", + assertion: require.NoError, + }, + { + location: "http://sptest.iamshowcase.com/acs", + assertion: require.Error, + }, + { + location: "javascript://sptest.iamshowcase.com/acs", + assertion: require.Error, + }, + } + + for _, test := range cases { + t.Run(test.location, func(t *testing.T) { + test.assertion(t, ValidateAssertionConsumerServicesEndpoint(test.location)) + }) + } +} + var samlIDPServiceProviderYAML = `--- kind: saml_idp_service_provider version: v1