Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
53 changes: 26 additions & 27 deletions lib/services/local/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (

"github.com/crewjam/saml/samlsp"
"github.com/gravitational/trace"
"github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/backend"
Expand Down Expand Up @@ -72,8 +73,18 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co

// CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
logrus.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

item, err := s.svc.MakeBackendItem(sp, sp.GetName())
Expand All @@ -97,8 +108,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context

// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
logrus.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

item, err := s.svc.MakeBackendItem(sp, sp.GetName())
Expand Down Expand Up @@ -158,26 +180,3 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte

return nil
}

// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID
// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints.
func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error {
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter(err.Error())
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

for _, descriptor := range ed.SPSSODescriptors {
for _, acs := range descriptor.AssertionConsumerServices {
if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil {
return trace.Wrap(err)
}
}
}

return nil
}
117 changes: 0 additions & 117 deletions lib/services/local/saml_idp_service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,8 @@ import (

"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -66,34 +64,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
})
require.NoError(t, err)

// Try to create a service provider with an invalid acs.
invalidSP, err := types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: "sp3",
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("sp1"),
EntityID: "sp1",
})
require.NoError(t, err)
err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Try to create a service provider with a http acs.
invalidSP, err = types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: "sp3",
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newHTTPACSEntityDescriptor("sp1"),
EntityID: "sp1",
})
require.NoError(t, err)
err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Initially we expect no service providers.
out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "")
require.NoError(t, err)
Expand Down Expand Up @@ -193,20 +163,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)

// Update a service provider with an invalid acs.
invalidSP, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
invalidSP.SetEntityDescriptor(newInvalidACSEntityDescriptor(invalidSP.GetEntityID()))
err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Update a service provider with a http acs.
invalidSP.SetEntityDescriptor(newHTTPACSEntityDescriptor(invalidSP.GetEntityID()))
err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "UpdateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Delete a service provider.
err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
Expand All @@ -230,49 +186,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
require.Empty(t, out)
}

func TestValidateSAMLIdPServiceProvider(t *testing.T) {
descriptor := newEntityDescriptor("IAMShowcase")

cases := []struct {
name string
spec types.SAMLIdPServiceProviderSpecV1
errAssertion require.ErrorAssertionFunc
}{
{
name: "valid provider",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: "IAMShowcase",
},
errAssertion: require.NoError,
},
{
name: "invalid entity id",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: uuid.NewString(),
},
errAssertion: require.Error,
},
{
name: "invalid acs",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"),
EntityID: "IAMShowcase",
},
errAssertion: require.Error,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec)
require.NoError(t, err)
test.errAssertion(t, validateSAMLIdPServiceProvider(sp))
})
}
}

func newEntityDescriptor(entityID string) string {
return fmt.Sprintf(testEntityDescriptor, entityID)
}
Expand All @@ -287,33 +200,3 @@ const testEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`

func newInvalidACSEntityDescriptor(entityID string) string {
return fmt.Sprintf(invalidEntityDescriptor, entityID)
}

// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations.
const invalidEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="%s" validUntil="2025-12-09T09:13:31.006Z">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="javascript://sptest.iamshowcase.com/acs" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`

func newHTTPACSEntityDescriptor(entityID string) string {
return fmt.Sprintf(httpEntityDescriptor, entityID)
}

// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with a http ACS locations.
const httpEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="%s" validUntil="2025-12-09T09:13:31.006Z">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="http://sptest.iamshowcase.com/acs" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`
76 changes: 74 additions & 2 deletions lib/services/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,10 @@ import (
"context"
"net/url"

"github.com/crewjam/saml"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/slices"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -117,15 +120,84 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string)
return &s, nil
}

// supportedACSBindings is the set of AssertionConsumerService bindings that teleport supports.
var supportedACSBindings = map[string]struct{}{
saml.HTTPPostBinding: {},
saml.HTTPRedirectBinding: {},
}

// ValidateAssertionConsumerService checks if a given assertion consumer service is usable by teleport. Note that
// it is permissible for a service provider to include acs endpoints that are not compatible with teleport, so long
// as at least one _is_ compatible.
func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error {
if _, ok := supportedACSBindings[acs.Binding]; !ok {
return trace.BadParameter("unsupported acs binding: %q", acs.Binding)
}

if acs.Location == "" {
return trace.BadParameter("acs location endpoint is missing or could not be decoded for %q binding", acs.Binding)
}

return trace.Wrap(ValidateAssertionConsumerServicesEndpoint(acs.Location))
}

// FilterSAMLEntityDescriptor performs a filter in place to remove unsupported and/or insecure fields from
// a saml entity descriptor. Specifically, it removes acs endpoints that are either of an unsupported kind,
// or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally
// expected that a service provider will successfully support a given ACS so long as they have at least one
// compatible binding.
func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor, quiet bool) error {
var originalCount int
var filteredCount int
for i := range ed.SPSSODescriptors {
filtered := deleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool {
if err := ValidateAssertionConsumerService(acs); err != nil {
if !quiet {
log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err)
}
return true
}

return false
})

originalCount += len(ed.SPSSODescriptors[i].AssertionConsumerServices)
filteredCount += len(filtered)

ed.SPSSODescriptors[i].AssertionConsumerServices = filtered
}

if filteredCount == 0 && originalCount != 0 {
return trace.BadParameter("no AssertionConsumerService bindings for entity %q passed validation", ed.EntityID)
}

return nil
}

func deleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
i := slices.IndexFunc(s, del)
if i == -1 {
return s
}
// Don't start copying elements until we find one to delete.
for j := i + 1; j < len(s); j++ {
if v := s[j]; !del(v) {
s[i] = v
i++
}
}
return s[:i]
}

// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location
// is a valid HTTPS endpoint.
func ValidateAssertionConsumerServicesEndpoint(acs string) error {
endpoint, err := url.Parse(acs)
switch {
case err != nil:
return trace.Wrap(err)
return trace.BadParameter("acs location endpoint %q could not be parsed: %v", acs, err)
case endpoint.Scheme != "https":
return trace.BadParameter("the assertion consumer services location must be an https endpoint")
return trace.BadParameter("invalid scheme %q in acs location endpoint %q (must be 'https')", endpoint.Scheme, acs)
}

return nil
Expand Down
Loading