diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go
index 69f8a46ca3a8b..db4e393efa8e1 100644
--- a/lib/services/local/saml_idp_service_provider.go
+++ b/lib/services/local/saml_idp_service_provider.go
@@ -22,6 +22,7 @@ import (
"github.com/crewjam/saml/samlsp"
"github.com/gravitational/trace"
+ "github.com/sirupsen/logrus"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/backend"
@@ -72,8 +73,18 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co
// CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
- if err := validateSAMLIdPServiceProvider(sp); err != nil {
- return trace.Wrap(err)
+ // verify that entity descriptor parses
+ ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
+ if err != nil {
+ return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
+ }
+
+ if ed.EntityID != sp.GetEntityID() {
+ return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
+ }
+
+ if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
+ logrus.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
@@ -97,8 +108,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
- if err := validateSAMLIdPServiceProvider(sp); err != nil {
- return trace.Wrap(err)
+ // verify that entity descriptor parses
+ ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
+ if err != nil {
+ return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
+ }
+
+ if ed.EntityID != sp.GetEntityID() {
+ return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
+ }
+
+ // ensure any filtering related issues get logged
+ if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
+ logrus.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}
item, err := s.svc.MakeBackendItem(sp, sp.GetName())
@@ -158,26 +180,3 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte
return nil
}
-
-// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID
-// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints.
-func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error {
- ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
- if err != nil {
- return trace.BadParameter(err.Error())
- }
-
- if ed.EntityID != sp.GetEntityID() {
- return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
- }
-
- for _, descriptor := range ed.SPSSODescriptors {
- for _, acs := range descriptor.AssertionConsumerServices {
- if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil {
- return trace.Wrap(err)
- }
- }
- }
-
- return nil
-}
diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go
index 80d5cd3dfd100..86a66cbdda2c9 100644
--- a/lib/services/local/saml_idp_service_provider_test.go
+++ b/lib/services/local/saml_idp_service_provider_test.go
@@ -23,10 +23,8 @@ import (
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
- "github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
- "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
@@ -66,34 +64,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
})
require.NoError(t, err)
- // Try to create a service provider with an invalid acs.
- invalidSP, err := types.NewSAMLIdPServiceProvider(
- types.Metadata{
- Name: "sp3",
- },
- types.SAMLIdPServiceProviderSpecV1{
- EntityDescriptor: newInvalidACSEntityDescriptor("sp1"),
- EntityID: "sp1",
- })
- require.NoError(t, err)
- err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP)
- assert.Error(t, err)
- require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))
-
- // Try to create a service provider with a http acs.
- invalidSP, err = types.NewSAMLIdPServiceProvider(
- types.Metadata{
- Name: "sp3",
- },
- types.SAMLIdPServiceProviderSpecV1{
- EntityDescriptor: newHTTPACSEntityDescriptor("sp1"),
- EntityID: "sp1",
- })
- require.NoError(t, err)
- err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP)
- assert.Error(t, err)
- require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))
-
// Initially we expect no service providers.
out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "")
require.NoError(t, err)
@@ -193,20 +163,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)
- // Update a service provider with an invalid acs.
- invalidSP, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName())
- require.NoError(t, err)
- invalidSP.SetEntityDescriptor(newInvalidACSEntityDescriptor(invalidSP.GetEntityID()))
- err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP)
- assert.Error(t, err)
- require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))
-
- // Update a service provider with a http acs.
- invalidSP.SetEntityDescriptor(newHTTPACSEntityDescriptor(invalidSP.GetEntityID()))
- err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP)
- assert.Error(t, err)
- require.True(t, trace.IsBadParameter(err), "UpdateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))
-
// Delete a service provider.
err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
@@ -230,49 +186,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
require.Empty(t, out)
}
-func TestValidateSAMLIdPServiceProvider(t *testing.T) {
- descriptor := newEntityDescriptor("IAMShowcase")
-
- cases := []struct {
- name string
- spec types.SAMLIdPServiceProviderSpecV1
- errAssertion require.ErrorAssertionFunc
- }{
- {
- name: "valid provider",
- spec: types.SAMLIdPServiceProviderSpecV1{
- EntityDescriptor: descriptor,
- EntityID: "IAMShowcase",
- },
- errAssertion: require.NoError,
- },
- {
- name: "invalid entity id",
- spec: types.SAMLIdPServiceProviderSpecV1{
- EntityDescriptor: descriptor,
- EntityID: uuid.NewString(),
- },
- errAssertion: require.Error,
- },
- {
- name: "invalid acs",
- spec: types.SAMLIdPServiceProviderSpecV1{
- EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"),
- EntityID: "IAMShowcase",
- },
- errAssertion: require.Error,
- },
- }
-
- for _, test := range cases {
- t.Run(test.name, func(t *testing.T) {
- sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec)
- require.NoError(t, err)
- test.errAssertion(t, validateSAMLIdPServiceProvider(sp))
- })
- }
-}
-
func newEntityDescriptor(entityID string) string {
return fmt.Sprintf(testEntityDescriptor, entityID)
}
@@ -287,33 +200,3 @@ const testEntityDescriptor = `
`
-
-func newInvalidACSEntityDescriptor(entityID string) string {
- return fmt.Sprintf(invalidEntityDescriptor, entityID)
-}
-
-// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations.
-const invalidEntityDescriptor = `
-
-
- urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified
- urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress
-
-
-
-`
-
-func newHTTPACSEntityDescriptor(entityID string) string {
- return fmt.Sprintf(httpEntityDescriptor, entityID)
-}
-
-// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with a http ACS locations.
-const httpEntityDescriptor = `
-
-
- urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified
- urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress
-
-
-
-`
diff --git a/lib/services/saml_idp_service_provider.go b/lib/services/saml_idp_service_provider.go
index dcd750a5008f6..dac9898a5bcc6 100644
--- a/lib/services/saml_idp_service_provider.go
+++ b/lib/services/saml_idp_service_provider.go
@@ -20,7 +20,10 @@ import (
"context"
"net/url"
+ "github.com/crewjam/saml"
"github.com/gravitational/trace"
+ log "github.com/sirupsen/logrus"
+ "golang.org/x/exp/slices"
"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
@@ -117,15 +120,84 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string)
return &s, nil
}
+// supportedACSBindings is the set of AssertionConsumerService bindings that teleport supports.
+var supportedACSBindings = map[string]struct{}{
+ saml.HTTPPostBinding: {},
+ saml.HTTPRedirectBinding: {},
+}
+
+// ValidateAssertionConsumerService checks if a given assertion consumer service is usable by teleport. Note that
+// it is permissible for a service provider to include acs endpoints that are not compatible with teleport, so long
+// as at least one _is_ compatible.
+func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error {
+ if _, ok := supportedACSBindings[acs.Binding]; !ok {
+ return trace.BadParameter("unsupported acs binding: %q", acs.Binding)
+ }
+
+ if acs.Location == "" {
+ return trace.BadParameter("acs location endpoint is missing or could not be decoded for %q binding", acs.Binding)
+ }
+
+ return trace.Wrap(ValidateAssertionConsumerServicesEndpoint(acs.Location))
+}
+
+// FilterSAMLEntityDescriptor performs a filter in place to remove unsupported and/or insecure fields from
+// a saml entity descriptor. Specifically, it removes acs endpoints that are either of an unsupported kind,
+// or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally
+// expected that a service provider will successfully support a given ACS so long as they have at least one
+// compatible binding.
+func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor, quiet bool) error {
+ var originalCount int
+ var filteredCount int
+ for i := range ed.SPSSODescriptors {
+ filtered := deleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool {
+ if err := ValidateAssertionConsumerService(acs); err != nil {
+ if !quiet {
+ log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err)
+ }
+ return true
+ }
+
+ return false
+ })
+
+ originalCount += len(ed.SPSSODescriptors[i].AssertionConsumerServices)
+ filteredCount += len(filtered)
+
+ ed.SPSSODescriptors[i].AssertionConsumerServices = filtered
+ }
+
+ if filteredCount == 0 && originalCount != 0 {
+ return trace.BadParameter("no AssertionConsumerService bindings for entity %q passed validation", ed.EntityID)
+ }
+
+ return nil
+}
+
+func deleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
+ i := slices.IndexFunc(s, del)
+ if i == -1 {
+ return s
+ }
+ // Don't start copying elements until we find one to delete.
+ for j := i + 1; j < len(s); j++ {
+ if v := s[j]; !del(v) {
+ s[i] = v
+ i++
+ }
+ }
+ return s[:i]
+}
+
// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location
// is a valid HTTPS endpoint.
func ValidateAssertionConsumerServicesEndpoint(acs string) error {
endpoint, err := url.Parse(acs)
switch {
case err != nil:
- return trace.Wrap(err)
+ return trace.BadParameter("acs location endpoint %q could not be parsed: %v", acs, err)
case endpoint.Scheme != "https":
- return trace.BadParameter("the assertion consumer services location must be an https endpoint")
+ return trace.BadParameter("invalid scheme %q in acs location endpoint %q (must be 'https')", endpoint.Scheme, acs)
}
return nil
diff --git a/lib/services/saml_idp_service_provider_test.go b/lib/services/saml_idp_service_provider_test.go
index 51987205e1c3d..50c88ba444f49 100644
--- a/lib/services/saml_idp_service_provider_test.go
+++ b/lib/services/saml_idp_service_provider_test.go
@@ -17,8 +17,12 @@ limitations under the License.
package services
import (
+ "fmt"
+ "strings"
"testing"
+ "github.com/crewjam/saml"
+ "github.com/crewjam/saml/samlsp"
"github.com/stretchr/testify/require"
"github.com/gravitational/teleport/api/types"
@@ -61,6 +65,83 @@ func TestSAMLIdPServiceProviderMarshal(t *testing.T) {
require.Equal(t, expected, actual)
}
+func TestFilterSAMLEntityDescriptor(t *testing.T) {
+ tts := []struct {
+ eds string
+ ok bool
+ before, after int
+ name string
+ }{
+ {
+ eds: edBuilder().
+ ACS(saml.HTTPPostBinding, "https://one.example.com/acs").
+ ACS(saml.HTTPPostBinding, "https://two.example.com/acs").
+ Done(),
+ ok: true,
+ before: 2,
+ after: 2,
+ name: "no filtering",
+ },
+ {
+ eds: edBuilder().
+ ACS(saml.HTTPPostBinding, "https://example.com/acs").
+ ACS(saml.HTTPPostBinding, "http://example.com/acs").
+ Done(),
+ ok: true,
+ before: 2,
+ after: 1,
+ name: "scheme filtering",
+ },
+ {
+ eds: edBuilder().
+ ACS(saml.HTTPArtifactBinding, "https://example.com/acs").
+ ACS("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST-SimpleSign", "https://example.com/POST-SimpleSign").
+ ACS(saml.HTTPPostBinding, "https://example.com/acs").
+ Done(),
+ ok: true,
+ before: 3,
+ after: 1,
+ name: "binding filtering",
+ },
+ {
+ eds: edBuilder().
+ ACS("urn:oasis:names:tc:SAML:2.0:bindings:PAOS", "https://example.com/ECP").
+ ACS(saml.HTTPPostBinding, "http://example.com/acs").
+ Done(),
+ ok: false,
+ before: 2,
+ after: 0,
+ name: "all invalid",
+ },
+ }
+
+ for _, tt := range tts {
+ t.Run(tt.name, func(t *testing.T) {
+ ed, err := samlsp.ParseMetadata([]byte(tt.eds))
+ require.NoError(t, err)
+
+ require.Equal(t, tt.before, getACSCount(ed))
+
+ err = FilterSAMLEntityDescriptor(ed, false /* quiet */)
+ if !tt.ok {
+ require.Error(t, err)
+ return
+ }
+ require.NoError(t, err)
+
+ require.Equal(t, tt.after, getACSCount(ed))
+ })
+ }
+}
+
+func getACSCount(ed *saml.EntityDescriptor) int {
+ var count int
+ for _, desc := range ed.SPSSODescriptors {
+ count += len(desc.AssertionConsumerServices)
+ }
+ return count
+}
+
func TestValidateAssertionConsumerServicesEndpoint(t *testing.T) {
cases := []struct {
location string
@@ -116,3 +197,41 @@ const testEntityDescriptor = `
`
+
+const edBuilderTemplate = `
+
+
+ urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified
+ urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress
+ %s
+
+
+`
+
+const edBuilderACSTemplate = ``
+
+type entityDescriptorBuilder struct {
+ acs []struct {
+ binding, location string
+ }
+}
+
+func edBuilder() *entityDescriptorBuilder {
+ return &entityDescriptorBuilder{}
+}
+
+func (b *entityDescriptorBuilder) ACS(binding, location string) *entityDescriptorBuilder {
+ b.acs = append(b.acs, struct {
+ binding, location string
+ }{binding, location})
+ return b
+}
+
+func (b *entityDescriptorBuilder) Done() string {
+ var acss []string
+ for i, acs := range b.acs {
+ acss = append(acss, fmt.Sprintf(edBuilderACSTemplate, acs.binding, acs.location, i))
+ }
+
+ return fmt.Sprintf(edBuilderTemplate, strings.Join(acss, "\n "))
+}
diff --git a/tool/tctl/common/resource_command.go b/tool/tctl/common/resource_command.go
index 6e175c8c9ce98..29dc8199b5d54 100644
--- a/tool/tctl/common/resource_command.go
+++ b/tool/tctl/common/resource_command.go
@@ -28,6 +28,7 @@ import (
"time"
"github.com/alecthomas/kingpin/v2"
+ "github.com/crewjam/saml/samlsp"
"github.com/gravitational/trace"
"github.com/gravitational/trace/trail"
log "github.com/sirupsen/logrus"
@@ -852,6 +853,19 @@ func (rc *ResourceCommand) createSAMLIdPServiceProvider(ctx context.Context, cli
return trace.Wrap(err)
}
+ if sp.GetEntityDescriptor() != "" {
+ // verify that entity descriptor parses
+ ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
+ if err != nil {
+ return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
+ }
+
+ // issue warning about unsupported ACS bindings.
+ if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
+ log.Warnf("Entity descriptor for SAML IdP service provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
+ }
+ }
+
serviceProviderName := sp.GetName()
if err := sp.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)