From 4fe6016388ef5e361764ec1ca4723cf403a47e23 Mon Sep 17 00:00:00 2001 From: Forrest Marshall Date: Thu, 21 Dec 2023 15:58:24 +0000 Subject: [PATCH 1/2] fix saml validation --- .../local/saml_idp_service_provider.go | 60 ++++----- .../local/saml_idp_service_provider_test.go | 117 ----------------- lib/services/saml_idp_service_provider.go | 59 ++++++++- .../saml_idp_service_provider_test.go | 118 ++++++++++++++++++ tool/tctl/common/resource_command.go | 13 ++ 5 files changed, 214 insertions(+), 153 deletions(-) diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index b93911b8281fa..2d4a99750c9b1 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -19,8 +19,6 @@ package local import ( "context" "encoding/xml" - "errors" - "io" "net/http" "net/url" "time" @@ -123,8 +121,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context } } - if err := validateSAMLIdPServiceProvider(sp); err != nil { - return trace.Wrap(err) + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err) + } + + if ed.EntityID != sp.GetEntityID() { + return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") + } + + // ensure any filtering related issues get logged + if err := services.FilterSAMLEntityDescriptor(ed); err != nil { + s.log.Warnf("Entity descriptor for SAML IdP service provider %q may be malformed: %v", sp.GetEntityID(), err) } // embed attribute mapping in entity descriptor @@ -153,8 +162,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context // UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - if err := validateSAMLIdPServiceProvider(sp); err != nil { - return trace.Wrap(err) + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err) + } + + if ed.EntityID != sp.GetEntityID() { + return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") + } + + // ensure any filtering related issues get logged + if err := services.FilterSAMLEntityDescriptor(ed); err != nil { + s.log.Warnf("Entity descriptor for SAML IdP service provider %q may be malformed: %v", sp.GetEntityID(), err) } // embed attribute mapping in entity descriptor @@ -220,34 +240,6 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte return nil } -// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID -// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints. -func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error { - ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) - if err != nil { - switch { - case errors.Is(err, io.EOF): - return trace.BadParameter("missing entity descriptor: %s", err.Error()) - default: - return trace.BadParameter(err.Error()) - } - } - - if ed.EntityID != sp.GetEntityID() { - return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") - } - - for _, descriptor := range ed.SPSSODescriptors { - for _, acs := range descriptor.AssertionConsumerServices { - if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil { - return trace.Wrap(err) - } - } - } - - return nil -} - // fetchAndSetEntityDescriptor fetches Service Provider entity descriptor (aka SP metadata) // from remote metadata endpoint (Entity ID) and sets it to sp if the xml format // is a valid Service Provider metadata format. diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go index d854284c06530..4ed3f69cc914e 100644 --- a/lib/services/local/saml_idp_service_provider_test.go +++ b/lib/services/local/saml_idp_service_provider_test.go @@ -28,10 +28,8 @@ import ( "github.com/crewjam/saml/samlsp" "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -71,34 +69,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { }) require.NoError(t, err) - // Try to create a service provider with an invalid acs. - invalidSP, err := types.NewSAMLIdPServiceProvider( - types.Metadata{ - Name: "sp3", - }, - types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: newInvalidACSEntityDescriptor("sp1"), - EntityID: "sp1", - }) - require.NoError(t, err) - err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - - // Try to create a service provider with a http acs. - invalidSP, err = types.NewSAMLIdPServiceProvider( - types.Metadata{ - Name: "sp3", - }, - types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: newHTTPACSEntityDescriptor("sp1"), - EntityID: "sp1", - }) - require.NoError(t, err) - err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - // Initially we expect no service providers. out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "") require.NoError(t, err) @@ -198,20 +168,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { err = service.UpdateSAMLIdPServiceProvider(ctx, sp) require.Error(t, err) - // Update a service provider with an invalid acs. - invalidSP, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName()) - require.NoError(t, err) - invalidSP.SetEntityDescriptor(newInvalidACSEntityDescriptor(invalidSP.GetEntityID())) - err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - - // Update a service provider with a http acs. - invalidSP.SetEntityDescriptor(newHTTPACSEntityDescriptor(invalidSP.GetEntityID())) - err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "UpdateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - // Delete a service provider. err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName()) require.NoError(t, err) @@ -235,49 +191,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { require.Empty(t, out) } -func TestValidateSAMLIdPServiceProvider(t *testing.T) { - descriptor := newEntityDescriptor("IAMShowcase") - - cases := []struct { - name string - spec types.SAMLIdPServiceProviderSpecV1 - errAssertion require.ErrorAssertionFunc - }{ - { - name: "valid provider", - spec: types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: descriptor, - EntityID: "IAMShowcase", - }, - errAssertion: require.NoError, - }, - { - name: "invalid entity id", - spec: types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: descriptor, - EntityID: uuid.NewString(), - }, - errAssertion: require.Error, - }, - { - name: "invalid acs", - spec: types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"), - EntityID: "IAMShowcase", - }, - errAssertion: require.Error, - }, - } - - for _, test := range cases { - t.Run(test.name, func(t *testing.T) { - sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec) - require.NoError(t, err) - test.errAssertion(t, validateSAMLIdPServiceProvider(sp)) - }) - } -} - func newEntityDescriptor(entityID string) string { return fmt.Sprintf(testEntityDescriptor, entityID) } @@ -293,36 +206,6 @@ const testEntityDescriptor = ` ` -func newInvalidACSEntityDescriptor(entityID string) string { - return fmt.Sprintf(invalidEntityDescriptor, entityID) -} - -// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations. -const invalidEntityDescriptor = ` - - - urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified - urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress - - - -` - -func newHTTPACSEntityDescriptor(entityID string) string { - return fmt.Sprintf(httpEntityDescriptor, entityID) -} - -// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with a http ACS locations. -const httpEntityDescriptor = ` - - - urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified - urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress - - - -` - func newSAMLSPMetadata(entityID, acsURL string) string { return fmt.Sprintf(samlSPMetadata, entityID, acsURL, acsURL) } diff --git a/lib/services/saml_idp_service_provider.go b/lib/services/saml_idp_service_provider.go index 8f0bd379b8c35..4efb00e1ee5df 100644 --- a/lib/services/saml_idp_service_provider.go +++ b/lib/services/saml_idp_service_provider.go @@ -19,8 +19,11 @@ package services import ( "context" "net/url" + "slices" + "github.com/crewjam/saml" "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/utils" @@ -125,15 +128,67 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string) return &s, nil } +// supportedACSBindings is the set of AssertionConsumerService bindings that teleport supports. +var supportedACSBindings = map[string]struct{}{ + saml.HTTPPostBinding: {}, + saml.HTTPRedirectBinding: {}, +} + +// ValidateAssertionConsumerService checks if a given assertion consumer service is usable by teleport. Note that +// it is permissible for a service provider to include acs endpoints that are not compatible with teleport, so long +// as at least one _is_ compatible. +func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error { + if _, ok := supportedACSBindings[acs.Binding]; !ok { + return trace.BadParameter("unsupported acs binding: %q", acs.Binding) + } + + if acs.Location == "" { + return trace.BadParameter("acs location endpoint is missing or could not be decoded for %q binding", acs.Binding) + } + + return trace.Wrap(ValidateAssertionConsumerServicesEndpoint(acs.Location)) +} + +// FilterSAMLEntityDescriptor performs a filter in place to remove unsupported and/or insecure fields from +// a saml entity descriptor. Specifically, it removes acs endpoints that are either of an unsupported kind, +// or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally +// expected that a service provider will successfully support a given ACS so long as they have at least one +// compatible binding. +func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor) error { + var originalCount int + var filteredCount int + for i := range ed.SPSSODescriptors { + filtered := slices.DeleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool { + if err := ValidateAssertionConsumerService(acs); err != nil { + log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err) + return true + } + + return false + }) + + originalCount += len(ed.SPSSODescriptors[i].AssertionConsumerServices) + filteredCount += len(filtered) + + ed.SPSSODescriptors[i].AssertionConsumerServices = filtered + } + + if filteredCount == 0 && originalCount != 0 { + return trace.BadParameter("no AssertionConsumerService bindings for entity %q passed validation", ed.EntityID) + } + + return nil +} + // ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location // is a valid HTTPS endpoint. func ValidateAssertionConsumerServicesEndpoint(acs string) error { endpoint, err := url.Parse(acs) switch { case err != nil: - return trace.Wrap(err) + return trace.BadParameter("acs location endpoint %q could not be parsed: %v", acs, err) case endpoint.Scheme != "https": - return trace.BadParameter("the assertion consumer services location must be an https endpoint") + return trace.BadParameter("invalid scheme %q in acs location endpoint %q (must be 'https')", endpoint.Scheme, acs) } return nil diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go index 51987205e1c3d..cd53658081f80 100644 --- a/lib/services/saml_idp_service_provider_test.go +++ b/lib/services/saml_idp_service_provider_test.go @@ -17,8 +17,12 @@ limitations under the License. package services import ( + "fmt" + "strings" "testing" + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -61,6 +65,82 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) { require.Equal(t, expected, actual) } +func TestFilterSAMLEntityDescriptor(t *testing.T) { + tts := []struct { + eds string + ok bool + before, after int + name string + }{ + { + eds: edBuilder(). + ACS(saml.HTTPPostBinding, "https://one.example.com/acs"). + ACS(saml.HTTPPostBinding, "https://two.example.com/acs"). + Done(), + ok: true, + before: 2, + after: 2, + name: "no filtering", + }, + { + eds: edBuilder(). + ACS(saml.HTTPPostBinding, "https://example.com/acs"). + ACS(saml.HTTPPostBinding, "http://example.com/acs"). + Done(), + ok: true, + before: 2, + after: 1, + name: "scheme filtering", + }, + { + eds: edBuilder(). + ACS(saml.HTTPArtifactBinding, "https://example.com/acs"). + ACS(saml.HTTPPostBinding, "https://example.com/acs"). + Done(), + ok: true, + before: 2, + after: 1, + name: "binding filtering", + }, + { + eds: edBuilder(). + ACS(saml.HTTPArtifactBinding, "https://example.com/acs"). + ACS(saml.HTTPPostBinding, "http://example.com/acs"). + Done(), + ok: false, + before: 2, + after: 0, + name: "all invalid", + }, + } + + for _, tt := range tts { + t.Run(tt.name, func(t *testing.T) { + ed, err := samlsp.ParseMetadata([]byte(tt.eds)) + require.NoError(t, err) + + require.Equal(t, tt.before, getACSCount(ed)) + + err = FilterSAMLEntityDescriptor(ed) + if !tt.ok { + require.Error(t, err) + return + } + require.NoError(t, err) + + require.Equal(t, tt.after, getACSCount(ed)) + }) + } +} + +func getACSCount(ed *saml.EntityDescriptor) int { + var count int + for _, desc := range ed.SPSSODescriptors { + count += len(desc.AssertionConsumerServices) + } + return count +} + func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) { cases := []struct { location string @@ -116,3 +196,41 @@ const testEntityDescriptor = ` ` + +const edBuilderTemplate = ` + + + urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + %s + + +` + +const edBuilderACSTemplate = `` + +type entityDescriptorBuilder struct { + acs []struct { + binding, location string + } +} + +func edBuilder() *entityDescriptorBuilder { + return &entityDescriptorBuilder{} +} + +func (b *entityDescriptorBuilder) ACS(binding, location string) *entityDescriptorBuilder { + b.acs = append(b.acs, struct { + binding, location string + }{binding, location}) + return b +} + +func (b *entityDescriptorBuilder) Done() string { + var acss []string + for i, acs := range b.acs { + acss = append(acss, fmt.Sprintf(edBuilderACSTemplate, acs.binding, acs.location, i)) + } + + return fmt.Sprintf(edBuilderTemplate, strings.Join(acss, "\n ")) +} diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 84a55dce1f1ff..7b539ce5fe75a 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -28,6 +28,7 @@ import ( "time" "github.com/alecthomas/kingpin/v2" + "github.com/crewjam/saml/samlsp" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" log "github.com/sirupsen/logrus" @@ -857,6 +858,18 @@ func (rc *ResourceCommand) createSAMLIdPServiceProvider(ctx context.Context, cli return trace.Wrap(err) } + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err) + } + + // try filtering the entity descriptor. if it can't be filtered down to a useable looking state, reject + // the creation attempt. + if err := services.FilterSAMLEntityDescriptor(ed); err != nil { + return trace.Wrap(err) + } + serviceProviderName := sp.GetName() if err := sp.CheckAndSetDefaults(); err != nil { return trace.Wrap(err) From 68b667085ec7717b6b5b17457d5e53a4e65f6a4d Mon Sep 17 00:00:00 2001 From: Sakshyam Shah Date: Fri, 12 Jan 2024 12:43:29 -0500 Subject: [PATCH 2/2] relax client side entity descriptor validation (#36602) * - FilterSAMLEntityDescriptor only if entity descriptor is not empty. - only issue warning on unsupported acs bindings. * include HTTPArtifactBinding in supported acs bindings * update TestFilterSAMLEntityDescriptor * exclude HTTPArtifactBinding and filter them from generated entity descriptor * remove HTTPArtifactBinding from test entity descriptor --- .../local/saml_idp_service_provider.go | 23 ++++++++++++------- .../local/saml_idp_service_provider_test.go | 3 +-- lib/services/saml_idp_service_provider.go | 6 +++-- .../saml_idp_service_provider_test.go | 7 +++--- tool/tctl/common/resource_command.go | 19 +++++++-------- 5 files changed, 34 insertions(+), 24 deletions(-) diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index 2d4a99750c9b1..d005f0735b1c7 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -124,7 +124,7 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context // verify that entity descriptor parses ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) if err != nil { - return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err) + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) } if ed.EntityID != sp.GetEntityID() { @@ -132,8 +132,8 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context } // ensure any filtering related issues get logged - if err := services.FilterSAMLEntityDescriptor(ed); err != nil { - s.log.Warnf("Entity descriptor for SAML IdP service provider %q may be malformed: %v", sp.GetEntityID(), err) + if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil { + s.log.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) } // embed attribute mapping in entity descriptor @@ -165,7 +165,7 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context // verify that entity descriptor parses ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) if err != nil { - return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err) + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) } if ed.EntityID != sp.GetEntityID() { @@ -173,8 +173,8 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context } // ensure any filtering related issues get logged - if err := services.FilterSAMLEntityDescriptor(ed); err != nil { - s.log.Warnf("Entity descriptor for SAML IdP service provider %q may be malformed: %v", sp.GetEntityID(), err) + if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil { + s.log.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) } // embed attribute mapping in entity descriptor @@ -288,12 +288,19 @@ func (s *SAMLIdPServiceProviderService) generateAndSetEntityDescriptor(sp types. AuthnNameIDFormat: saml.UnspecifiedNameIDFormat, } - entityDescriptor, err := xml.MarshalIndent(newServiceProvider.Metadata(), "", " ") + ed := newServiceProvider.Metadata() + // HTTPArtifactBinding is defined when entity descriptor is generated + // using crewjam/saml https://github.com/crewjam/saml/blob/main/service_provider.go#L228. + // But we do not support it, so filter it out below. + // Error and warnings are swallowed because the descriptor is Teleport generated and + // users have no control over sanitizing filtered binding. + services.FilterSAMLEntityDescriptor(ed, true /* quiet */) + edXMLBytes, err := xml.MarshalIndent(ed, "", " ") if err != nil { return trace.Wrap(err) } - sp.SetEntityDescriptor(string(entityDescriptor)) + sp.SetEntityDescriptor(string(edXMLBytes)) return nil } diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go index 4ed3f69cc914e..ef5c33dcb195f 100644 --- a/lib/services/local/saml_idp_service_provider_test.go +++ b/lib/services/local/saml_idp_service_provider_test.go @@ -207,7 +207,7 @@ const testEntityDescriptor = ` ` func newSAMLSPMetadata(entityID, acsURL string) string { - return fmt.Sprintf(samlSPMetadata, entityID, acsURL, acsURL) + return fmt.Sprintf(samlSPMetadata, entityID, acsURL) } // samlSPMetadata mimics metadata generated by saml.ServiceProvider.Metadata() @@ -215,7 +215,6 @@ const samlSPMetadata = ` urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified - ` diff --git a/lib/services/saml_idp_service_provider.go b/lib/services/saml_idp_service_provider.go index 4efb00e1ee5df..3f288ad8fa748 100644 --- a/lib/services/saml_idp_service_provider.go +++ b/lib/services/saml_idp_service_provider.go @@ -154,13 +154,15 @@ func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error { // or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally // expected that a service provider will successfully support a given ACS so long as they have at least one // compatible binding. -func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor) error { +func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor, quiet bool) error { var originalCount int var filteredCount int for i := range ed.SPSSODescriptors { filtered := slices.DeleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool { if err := ValidateAssertionConsumerService(acs); err != nil { - log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err) + if !quiet { + log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err) + } return true } diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go index cd53658081f80..50c88ba444f49 100644 --- a/lib/services/saml_idp_service_provider_test.go +++ b/lib/services/saml_idp_service_provider_test.go @@ -95,16 +95,17 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) { { eds: edBuilder(). ACS(saml.HTTPArtifactBinding, "https://example.com/acs"). + ACS("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST-SimpleSign", "https://example.com/POST-SimpleSign"). ACS(saml.HTTPPostBinding, "https://example.com/acs"). Done(), ok: true, - before: 2, + before: 3, after: 1, name: "binding filtering", }, { eds: edBuilder(). - ACS(saml.HTTPArtifactBinding, "https://example.com/acs"). + ACS("urn:oasis:names:tc:SAML:2.0:bindings:PAOS", "https://example.com/ECP"). ACS(saml.HTTPPostBinding, "http://example.com/acs"). Done(), ok: false, @@ -121,7 +122,7 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) { require.Equal(t, tt.before, getACSCount(ed)) - err = FilterSAMLEntityDescriptor(ed) + err = FilterSAMLEntityDescriptor(ed, false /* quiet */) if !tt.ok { require.Error(t, err) return diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 7b539ce5fe75a..18778db081065 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -858,16 +858,17 @@ func (rc *ResourceCommand) createSAMLIdPServiceProvider(ctx context.Context, cli return trace.Wrap(err) } - // verify that entity descriptor parses - ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) - if err != nil { - return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err) - } + if sp.GetEntityDescriptor() != "" { + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) + } - // try filtering the entity descriptor. if it can't be filtered down to a useable looking state, reject - // the creation attempt. - if err := services.FilterSAMLEntityDescriptor(ed); err != nil { - return trace.Wrap(err) + // issue warning about unsupported ACS bindings. + if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil { + log.Warnf("Entity descriptor for SAML IdP service provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) + } } serviceProviderName := sp.GetName()