diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index 69f8a46ca3a8b..db4e393efa8e1 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -22,6 +22,7 @@ import ( "github.com/crewjam/saml/samlsp" "github.com/gravitational/trace" + "github.com/sirupsen/logrus" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/backend" @@ -72,8 +73,18 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co // CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - if err := validateSAMLIdPServiceProvider(sp); err != nil { - return trace.Wrap(err) + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) + } + + if ed.EntityID != sp.GetEntityID() { + return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") + } + + if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil { + logrus.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) } item, err := s.svc.MakeBackendItem(sp, sp.GetName()) @@ -97,8 +108,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context // UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { - if err := validateSAMLIdPServiceProvider(sp); err != nil { - return trace.Wrap(err) + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) + } + + if ed.EntityID != sp.GetEntityID() { + return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") + } + + // ensure any filtering related issues get logged + if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil { + logrus.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) } item, err := s.svc.MakeBackendItem(sp, sp.GetName()) @@ -158,26 +180,3 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte return nil } - -// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID -// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints. -func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error { - ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) - if err != nil { - return trace.BadParameter(err.Error()) - } - - if ed.EntityID != sp.GetEntityID() { - return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object") - } - - for _, descriptor := range ed.SPSSODescriptors { - for _, acs := range descriptor.AssertionConsumerServices { - if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil { - return trace.Wrap(err) - } - } - } - - return nil -} diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go index 80d5cd3dfd100..86a66cbdda2c9 100644 --- a/lib/services/local/saml_idp_service_provider_test.go +++ b/lib/services/local/saml_idp_service_provider_test.go @@ -23,10 +23,8 @@ import ( "github.com/google/go-cmp/cmp" "github.com/google/go-cmp/cmp/cmpopts" - "github.com/google/uuid" "github.com/gravitational/trace" "github.com/jonboulle/clockwork" - "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -66,34 +64,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { }) require.NoError(t, err) - // Try to create a service provider with an invalid acs. - invalidSP, err := types.NewSAMLIdPServiceProvider( - types.Metadata{ - Name: "sp3", - }, - types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: newInvalidACSEntityDescriptor("sp1"), - EntityID: "sp1", - }) - require.NoError(t, err) - err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - - // Try to create a service provider with a http acs. - invalidSP, err = types.NewSAMLIdPServiceProvider( - types.Metadata{ - Name: "sp3", - }, - types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: newHTTPACSEntityDescriptor("sp1"), - EntityID: "sp1", - }) - require.NoError(t, err) - err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - // Initially we expect no service providers. out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "") require.NoError(t, err) @@ -193,20 +163,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { err = service.UpdateSAMLIdPServiceProvider(ctx, sp) require.Error(t, err) - // Update a service provider with an invalid acs. - invalidSP, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName()) - require.NoError(t, err) - invalidSP.SetEntityDescriptor(newInvalidACSEntityDescriptor(invalidSP.GetEntityID())) - err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - - // Update a service provider with a http acs. - invalidSP.SetEntityDescriptor(newHTTPACSEntityDescriptor(invalidSP.GetEntityID())) - err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP) - assert.Error(t, err) - require.True(t, trace.IsBadParameter(err), "UpdateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err)) - // Delete a service provider. err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName()) require.NoError(t, err) @@ -230,49 +186,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) { require.Empty(t, out) } -func TestValidateSAMLIdPServiceProvider(t *testing.T) { - descriptor := newEntityDescriptor("IAMShowcase") - - cases := []struct { - name string - spec types.SAMLIdPServiceProviderSpecV1 - errAssertion require.ErrorAssertionFunc - }{ - { - name: "valid provider", - spec: types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: descriptor, - EntityID: "IAMShowcase", - }, - errAssertion: require.NoError, - }, - { - name: "invalid entity id", - spec: types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: descriptor, - EntityID: uuid.NewString(), - }, - errAssertion: require.Error, - }, - { - name: "invalid acs", - spec: types.SAMLIdPServiceProviderSpecV1{ - EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"), - EntityID: "IAMShowcase", - }, - errAssertion: require.Error, - }, - } - - for _, test := range cases { - t.Run(test.name, func(t *testing.T) { - sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec) - require.NoError(t, err) - test.errAssertion(t, validateSAMLIdPServiceProvider(sp)) - }) - } -} - func newEntityDescriptor(entityID string) string { return fmt.Sprintf(testEntityDescriptor, entityID) } @@ -287,33 +200,3 @@ const testEntityDescriptor = ` ` - -func newInvalidACSEntityDescriptor(entityID string) string { - return fmt.Sprintf(invalidEntityDescriptor, entityID) -} - -// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations. -const invalidEntityDescriptor = ` - - - urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified - urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress - - - -` - -func newHTTPACSEntityDescriptor(entityID string) string { - return fmt.Sprintf(httpEntityDescriptor, entityID) -} - -// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with a http ACS locations. -const httpEntityDescriptor = ` - - - urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified - urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress - - - -` diff --git a/lib/services/saml_idp_service_provider.go b/lib/services/saml_idp_service_provider.go index dcd750a5008f6..dac9898a5bcc6 100644 --- a/lib/services/saml_idp_service_provider.go +++ b/lib/services/saml_idp_service_provider.go @@ -20,7 +20,10 @@ import ( "context" "net/url" + "github.com/crewjam/saml" "github.com/gravitational/trace" + log "github.com/sirupsen/logrus" + "golang.org/x/exp/slices" "github.com/gravitational/teleport/api/types" "github.com/gravitational/teleport/lib/utils" @@ -117,15 +120,84 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string) return &s, nil } +// supportedACSBindings is the set of AssertionConsumerService bindings that teleport supports. +var supportedACSBindings = map[string]struct{}{ + saml.HTTPPostBinding: {}, + saml.HTTPRedirectBinding: {}, +} + +// ValidateAssertionConsumerService checks if a given assertion consumer service is usable by teleport. Note that +// it is permissible for a service provider to include acs endpoints that are not compatible with teleport, so long +// as at least one _is_ compatible. +func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error { + if _, ok := supportedACSBindings[acs.Binding]; !ok { + return trace.BadParameter("unsupported acs binding: %q", acs.Binding) + } + + if acs.Location == "" { + return trace.BadParameter("acs location endpoint is missing or could not be decoded for %q binding", acs.Binding) + } + + return trace.Wrap(ValidateAssertionConsumerServicesEndpoint(acs.Location)) +} + +// FilterSAMLEntityDescriptor performs a filter in place to remove unsupported and/or insecure fields from +// a saml entity descriptor. Specifically, it removes acs endpoints that are either of an unsupported kind, +// or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally +// expected that a service provider will successfully support a given ACS so long as they have at least one +// compatible binding. +func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor, quiet bool) error { + var originalCount int + var filteredCount int + for i := range ed.SPSSODescriptors { + filtered := deleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool { + if err := ValidateAssertionConsumerService(acs); err != nil { + if !quiet { + log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err) + } + return true + } + + return false + }) + + originalCount += len(ed.SPSSODescriptors[i].AssertionConsumerServices) + filteredCount += len(filtered) + + ed.SPSSODescriptors[i].AssertionConsumerServices = filtered + } + + if filteredCount == 0 && originalCount != 0 { + return trace.BadParameter("no AssertionConsumerService bindings for entity %q passed validation", ed.EntityID) + } + + return nil +} + +func deleteFunc[S ~[]E, E any](s S, del func(E) bool) S { + i := slices.IndexFunc(s, del) + if i == -1 { + return s + } + // Don't start copying elements until we find one to delete. + for j := i + 1; j < len(s); j++ { + if v := s[j]; !del(v) { + s[i] = v + i++ + } + } + return s[:i] +} + // ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location // is a valid HTTPS endpoint. func ValidateAssertionConsumerServicesEndpoint(acs string) error { endpoint, err := url.Parse(acs) switch { case err != nil: - return trace.Wrap(err) + return trace.BadParameter("acs location endpoint %q could not be parsed: %v", acs, err) case endpoint.Scheme != "https": - return trace.BadParameter("the assertion consumer services location must be an https endpoint") + return trace.BadParameter("invalid scheme %q in acs location endpoint %q (must be 'https')", endpoint.Scheme, acs) } return nil diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go index 51987205e1c3d..50c88ba444f49 100644 --- a/lib/services/saml_idp_service_provider_test.go +++ b/lib/services/saml_idp_service_provider_test.go @@ -17,8 +17,12 @@ limitations under the License. package services import ( + "fmt" + "strings" "testing" + "github.com/crewjam/saml" + "github.com/crewjam/saml/samlsp" "github.com/stretchr/testify/require" "github.com/gravitational/teleport/api/types" @@ -61,6 +65,83 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) { require.Equal(t, expected, actual) } +func TestFilterSAMLEntityDescriptor(t *testing.T) { + tts := []struct { + eds string + ok bool + before, after int + name string + }{ + { + eds: edBuilder(). + ACS(saml.HTTPPostBinding, "https://one.example.com/acs"). + ACS(saml.HTTPPostBinding, "https://two.example.com/acs"). + Done(), + ok: true, + before: 2, + after: 2, + name: "no filtering", + }, + { + eds: edBuilder(). + ACS(saml.HTTPPostBinding, "https://example.com/acs"). + ACS(saml.HTTPPostBinding, "http://example.com/acs"). + Done(), + ok: true, + before: 2, + after: 1, + name: "scheme filtering", + }, + { + eds: edBuilder(). + ACS(saml.HTTPArtifactBinding, "https://example.com/acs"). + ACS("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST-SimpleSign", "https://example.com/POST-SimpleSign"). + ACS(saml.HTTPPostBinding, "https://example.com/acs"). + Done(), + ok: true, + before: 3, + after: 1, + name: "binding filtering", + }, + { + eds: edBuilder(). + ACS("urn:oasis:names:tc:SAML:2.0:bindings:PAOS", "https://example.com/ECP"). + ACS(saml.HTTPPostBinding, "http://example.com/acs"). + Done(), + ok: false, + before: 2, + after: 0, + name: "all invalid", + }, + } + + for _, tt := range tts { + t.Run(tt.name, func(t *testing.T) { + ed, err := samlsp.ParseMetadata([]byte(tt.eds)) + require.NoError(t, err) + + require.Equal(t, tt.before, getACSCount(ed)) + + err = FilterSAMLEntityDescriptor(ed, false /* quiet */) + if !tt.ok { + require.Error(t, err) + return + } + require.NoError(t, err) + + require.Equal(t, tt.after, getACSCount(ed)) + }) + } +} + +func getACSCount(ed *saml.EntityDescriptor) int { + var count int + for _, desc := range ed.SPSSODescriptors { + count += len(desc.AssertionConsumerServices) + } + return count +} + func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) { cases := []struct { location string @@ -116,3 +197,41 @@ const testEntityDescriptor = ` ` + +const edBuilderTemplate = ` + + + urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified + urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress + %s + + +` + +const edBuilderACSTemplate = `` + +type entityDescriptorBuilder struct { + acs []struct { + binding, location string + } +} + +func edBuilder() *entityDescriptorBuilder { + return &entityDescriptorBuilder{} +} + +func (b *entityDescriptorBuilder) ACS(binding, location string) *entityDescriptorBuilder { + b.acs = append(b.acs, struct { + binding, location string + }{binding, location}) + return b +} + +func (b *entityDescriptorBuilder) Done() string { + var acss []string + for i, acs := range b.acs { + acss = append(acss, fmt.Sprintf(edBuilderACSTemplate, acs.binding, acs.location, i)) + } + + return fmt.Sprintf(edBuilderTemplate, strings.Join(acss, "\n ")) +} diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go index 6e175c8c9ce98..29dc8199b5d54 100644 --- a/tool/tctl/common/resource_command.go +++ b/tool/tctl/common/resource_command.go @@ -28,6 +28,7 @@ import ( "time" "github.com/alecthomas/kingpin/v2" + "github.com/crewjam/saml/samlsp" "github.com/gravitational/trace" "github.com/gravitational/trace/trail" log "github.com/sirupsen/logrus" @@ -852,6 +853,19 @@ func (rc *ResourceCommand) createSAMLIdPServiceProvider(ctx context.Context, cli return trace.Wrap(err) } + if sp.GetEntityDescriptor() != "" { + // verify that entity descriptor parses + ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor())) + if err != nil { + return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err) + } + + // issue warning about unsupported ACS bindings. + if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil { + log.Warnf("Entity descriptor for SAML IdP service provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err) + } + } + serviceProviderName := sp.GetName() if err := sp.CheckAndSetDefaults(); err != nil { return trace.Wrap(err)