Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
71 changes: 35 additions & 36 deletions lib/services/local/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@ package local
import (
"context"
"encoding/xml"
"errors"
"io"
"net/http"
"net/url"
"time"
Expand Down Expand Up @@ -123,8 +121,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
}
}

if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
s.log.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

// embed attribute mapping in entity descriptor
Expand Down Expand Up @@ -153,8 +162,19 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context

// UpdateSAMLIdPServiceProvider updates an existing SAML IdP service provider resource.
func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error {
if err := validateSAMLIdPServiceProvider(sp); err != nil {
return trace.Wrap(err)
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
s.log.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

// embed attribute mapping in entity descriptor
Expand Down Expand Up @@ -220,34 +240,6 @@ func (s *SAMLIdPServiceProviderService) ensureEntityIDIsUnique(ctx context.Conte
return nil
}

// validateSAMLIdPServiceProvider ensures that the entity ID in the entity descriptor is the same as the entity ID
// in the [types.SAMLIdPServiceProvider] and that all AssertionConsumerServices defined are valid HTTPS endpoints.
func validateSAMLIdPServiceProvider(sp types.SAMLIdPServiceProvider) error {
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
switch {
case errors.Is(err, io.EOF):
return trace.BadParameter("missing entity descriptor: %s", err.Error())
default:
return trace.BadParameter(err.Error())
}
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

for _, descriptor := range ed.SPSSODescriptors {
for _, acs := range descriptor.AssertionConsumerServices {
if err := services.ValidateAssertionConsumerServicesEndpoint(acs.Location); err != nil {
return trace.Wrap(err)
}
}
}

return nil
}

// fetchAndSetEntityDescriptor fetches Service Provider entity descriptor (aka SP metadata)
// from remote metadata endpoint (Entity ID) and sets it to sp if the xml format
// is a valid Service Provider metadata format.
Expand Down Expand Up @@ -296,12 +288,19 @@ func (s *SAMLIdPServiceProviderService) generateAndSetEntityDescriptor(sp types.
AuthnNameIDFormat: saml.UnspecifiedNameIDFormat,
}

entityDescriptor, err := xml.MarshalIndent(newServiceProvider.Metadata(), "", " ")
ed := newServiceProvider.Metadata()
// HTTPArtifactBinding is defined when entity descriptor is generated
// using crewjam/saml https://github.com/crewjam/saml/blob/main/service_provider.go#L228.
// But we do not support it, so filter it out below.
// Error and warnings are swallowed because the descriptor is Teleport generated and
// users have no control over sanitizing filtered binding.
services.FilterSAMLEntityDescriptor(ed, true /* quiet */)
edXMLBytes, err := xml.MarshalIndent(ed, "", " ")
if err != nil {
return trace.Wrap(err)
}

sp.SetEntityDescriptor(string(entityDescriptor))
sp.SetEntityDescriptor(string(edXMLBytes))
return nil
}

Expand Down
120 changes: 1 addition & 119 deletions lib/services/local/saml_idp_service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,8 @@ import (
"github.com/crewjam/saml/samlsp"
"github.com/google/go-cmp/cmp"
"github.com/google/go-cmp/cmp/cmpopts"
"github.com/google/uuid"
"github.com/gravitational/trace"
"github.com/jonboulle/clockwork"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"

"github.com/gravitational/teleport/api/types"
Expand Down Expand Up @@ -71,34 +69,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
})
require.NoError(t, err)

// Try to create a service provider with an invalid acs.
invalidSP, err := types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: "sp3",
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("sp1"),
EntityID: "sp1",
})
require.NoError(t, err)
err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Try to create a service provider with a http acs.
invalidSP, err = types.NewSAMLIdPServiceProvider(
types.Metadata{
Name: "sp3",
},
types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newHTTPACSEntityDescriptor("sp1"),
EntityID: "sp1",
})
require.NoError(t, err)
err = service.CreateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Initially we expect no service providers.
out, nextToken, err := service.ListSAMLIdPServiceProviders(ctx, 200, "")
require.NoError(t, err)
Expand Down Expand Up @@ -198,20 +168,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
err = service.UpdateSAMLIdPServiceProvider(ctx, sp)
require.Error(t, err)

// Update a service provider with an invalid acs.
invalidSP, err = service.GetSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
invalidSP.SetEntityDescriptor(newInvalidACSEntityDescriptor(invalidSP.GetEntityID()))
err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "CreateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Update a service provider with a http acs.
invalidSP.SetEntityDescriptor(newHTTPACSEntityDescriptor(invalidSP.GetEntityID()))
err = service.UpdateSAMLIdPServiceProvider(ctx, invalidSP)
assert.Error(t, err)
require.True(t, trace.IsBadParameter(err), "UpdateSAMLIdPServiceProvider error mismatch, wanted BadParameter, got %q (%T)", err, trace.Unwrap(err))

// Delete a service provider.
err = service.DeleteSAMLIdPServiceProvider(ctx, sp1.GetName())
require.NoError(t, err)
Expand All @@ -235,49 +191,6 @@ func TestSAMLIdPServiceProviderCRUD(t *testing.T) {
require.Empty(t, out)
}

func TestValidateSAMLIdPServiceProvider(t *testing.T) {
descriptor := newEntityDescriptor("IAMShowcase")

cases := []struct {
name string
spec types.SAMLIdPServiceProviderSpecV1
errAssertion require.ErrorAssertionFunc
}{
{
name: "valid provider",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: "IAMShowcase",
},
errAssertion: require.NoError,
},
{
name: "invalid entity id",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: descriptor,
EntityID: uuid.NewString(),
},
errAssertion: require.Error,
},
{
name: "invalid acs",
spec: types.SAMLIdPServiceProviderSpecV1{
EntityDescriptor: newInvalidACSEntityDescriptor("IAMShowcase"),
EntityID: "IAMShowcase",
},
errAssertion: require.Error,
},
}

for _, test := range cases {
t.Run(test.name, func(t *testing.T) {
sp, err := types.NewSAMLIdPServiceProvider(types.Metadata{Name: "sp"}, test.spec)
require.NoError(t, err)
test.errAssertion(t, validateSAMLIdPServiceProvider(sp))
})
}
}

func newEntityDescriptor(entityID string) string {
return fmt.Sprintf(testEntityDescriptor, entityID)
}
Expand All @@ -293,46 +206,15 @@ const testEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
</md:EntityDescriptor>
`

func newInvalidACSEntityDescriptor(entityID string) string {
return fmt.Sprintf(invalidEntityDescriptor, entityID)
}

// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with invalid ACS locations.
const invalidEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="%s" validUntil="2025-12-09T09:13:31.006Z">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="javascript://sptest.iamshowcase.com/acs" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`

func newHTTPACSEntityDescriptor(entityID string) string {
return fmt.Sprintf(httpEntityDescriptor, entityID)
}

// A test entity descriptor from https://sptest.iamshowcase.com/testsp_metadata.xml with a http ACS locations.
const httpEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
<md:EntityDescriptor xmlns:md="urn:oasis:names:tc:SAML:2.0:metadata" xmlns:ds="http://www.w3.org/2000/09/xmldsig#" entityID="%s" validUntil="2025-12-09T09:13:31.006Z">
<md:SPSSODescriptor AuthnRequestsSigned="false" WantAssertionsSigned="true" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol">
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</md:NameIDFormat>
<md:NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:emailAddress</md:NameIDFormat>
<md:AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="http://sptest.iamshowcase.com/acs" index="0" isDefault="true"/>
</md:SPSSODescriptor>
</md:EntityDescriptor>
`

func newSAMLSPMetadata(entityID, acsURL string) string {
return fmt.Sprintf(samlSPMetadata, entityID, acsURL, acsURL)
return fmt.Sprintf(samlSPMetadata, entityID, acsURL)
}

// samlSPMetadata mimics metadata generated by saml.ServiceProvider.Metadata()
const samlSPMetadata = `<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" validUntil="2023-12-09T23:43:58.16Z" entityID="%s">
<SPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" validUntil="2023-12-09T23:43:58.16Z" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol" AuthnRequestsSigned="false" WantAssertionsSigned="true">
<NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</NameIDFormat>
<AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="%s" index="1"></AssertionConsumerService>
<AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact" Location="%s" index="2"></AssertionConsumerService>
</SPSSODescriptor>
</EntityDescriptor>
`
Expand Down
61 changes: 59 additions & 2 deletions lib/services/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,11 @@ package services
import (
"context"
"net/url"
"slices"

"github.com/crewjam/saml"
"github.com/gravitational/trace"
log "github.com/sirupsen/logrus"

"github.com/gravitational/teleport/api/types"
"github.com/gravitational/teleport/lib/utils"
Expand Down Expand Up @@ -125,15 +128,69 @@ func GenerateIdPServiceProviderFromFields(name string, entityDescriptor string)
return &s, nil
}

// supportedACSBindings is the set of AssertionConsumerService bindings that teleport supports.
var supportedACSBindings = map[string]struct{}{
saml.HTTPPostBinding: {},
saml.HTTPRedirectBinding: {},
}

// ValidateAssertionConsumerService checks if a given assertion consumer service is usable by teleport. Note that
// it is permissible for a service provider to include acs endpoints that are not compatible with teleport, so long
// as at least one _is_ compatible.
func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error {
if _, ok := supportedACSBindings[acs.Binding]; !ok {
return trace.BadParameter("unsupported acs binding: %q", acs.Binding)
}

if acs.Location == "" {
return trace.BadParameter("acs location endpoint is missing or could not be decoded for %q binding", acs.Binding)
}

return trace.Wrap(ValidateAssertionConsumerServicesEndpoint(acs.Location))
}

// FilterSAMLEntityDescriptor performs a filter in place to remove unsupported and/or insecure fields from
// a saml entity descriptor. Specifically, it removes acs endpoints that are either of an unsupported kind,
// or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally
// expected that a service provider will successfully support a given ACS so long as they have at least one
// compatible binding.
func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor, quiet bool) error {
var originalCount int
var filteredCount int
for i := range ed.SPSSODescriptors {
filtered := slices.DeleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool {
if err := ValidateAssertionConsumerService(acs); err != nil {
if !quiet {
log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err)
}
return true
}

return false
})

originalCount += len(ed.SPSSODescriptors[i].AssertionConsumerServices)
filteredCount += len(filtered)

ed.SPSSODescriptors[i].AssertionConsumerServices = filtered
}

if filteredCount == 0 && originalCount != 0 {
return trace.BadParameter("no AssertionConsumerService bindings for entity %q passed validation", ed.EntityID)
}

return nil
}

// ValidateAssertionConsumerServicesEndpoint ensures that the Assertion Consumer Service location
// is a valid HTTPS endpoint.
func ValidateAssertionConsumerServicesEndpoint(acs string) error {
endpoint, err := url.Parse(acs)
switch {
case err != nil:
return trace.Wrap(err)
return trace.BadParameter("acs location endpoint %q could not be parsed: %v", acs, err)
case endpoint.Scheme != "https":
return trace.BadParameter("the assertion consumer services location must be an https endpoint")
return trace.BadParameter("invalid scheme %q in acs location endpoint %q (must be 'https')", endpoint.Scheme, acs)
}

return nil
Expand Down
Loading