Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 15 additions & 8 deletions lib/services/local/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,16 +126,16 @@ func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err)
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed); err != nil {
s.log.Warnf("Entity descriptor for SAML IdP service provider %q may be malformed: %v", sp.GetEntityID(), err)
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
s.log.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

// embed attribute mapping in entity descriptor
Expand Down Expand Up @@ -167,16 +167,16 @@ func (s *SAMLIdPServiceProviderService) UpdateSAMLIdPServiceProvider(ctx context
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err)
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

if ed.EntityID != sp.GetEntityID() {
return trace.BadParameter("entity ID parsed from the entity descriptor does not match the entity ID in the SAML IdP service provider object")
}

// ensure any filtering related issues get logged
if err := services.FilterSAMLEntityDescriptor(ed); err != nil {
s.log.Warnf("Entity descriptor for SAML IdP service provider %q may be malformed: %v", sp.GetEntityID(), err)
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
s.log.Warnf("Entity descriptor for SAML IdP Service Provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}

// embed attribute mapping in entity descriptor
Expand Down Expand Up @@ -290,12 +290,19 @@ func (s *SAMLIdPServiceProviderService) generateAndSetEntityDescriptor(sp types.
AuthnNameIDFormat: saml.UnspecifiedNameIDFormat,
}

entityDescriptor, err := xml.MarshalIndent(newServiceProvider.Metadata(), "", " ")
ed := newServiceProvider.Metadata()
// HTTPArtifactBinding is defined when entity descriptor is generated
// using crewjam/saml https://github.com/crewjam/saml/blob/main/service_provider.go#L228.
// But we do not support it, so filter it out below.
// Error and warnings are swallowed because the descriptor is Teleport generated and
// users have no control over sanitizing filtered binding.
services.FilterSAMLEntityDescriptor(ed, true /* quiet */)
edXMLBytes, err := xml.MarshalIndent(ed, "", " ")
if err != nil {
return trace.Wrap(err)
}

sp.SetEntityDescriptor(string(entityDescriptor))
sp.SetEntityDescriptor(string(edXMLBytes))
return nil
}

Expand Down
3 changes: 1 addition & 2 deletions lib/services/local/saml_idp_service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -209,15 +209,14 @@ const testEntityDescriptor = `<?xml version="1.0" encoding="UTF-8"?>
`

func newSAMLSPMetadata(entityID, acsURL string) string {
return fmt.Sprintf(samlSPMetadata, entityID, acsURL, acsURL)
return fmt.Sprintf(samlSPMetadata, entityID, acsURL)
}

// samlSPMetadata mimics metadata generated by saml.ServiceProvider.Metadata()
const samlSPMetadata = `<EntityDescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" validUntil="2023-12-09T23:43:58.16Z" entityID="%s">
<SPSSODescriptor xmlns="urn:oasis:names:tc:SAML:2.0:metadata" validUntil="2023-12-09T23:43:58.16Z" protocolSupportEnumeration="urn:oasis:names:tc:SAML:2.0:protocol" AuthnRequestsSigned="false" WantAssertionsSigned="true">
<NameIDFormat>urn:oasis:names:tc:SAML:1.1:nameid-format:unspecified</NameIDFormat>
<AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST" Location="%s" index="1"></AssertionConsumerService>
<AssertionConsumerService Binding="urn:oasis:names:tc:SAML:2.0:bindings:HTTP-Artifact" Location="%s" index="2"></AssertionConsumerService>
</SPSSODescriptor>
</EntityDescriptor>
`
Expand Down
6 changes: 4 additions & 2 deletions lib/services/saml_idp_service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -150,13 +150,15 @@ func ValidateAssertionConsumerService(acs saml.IndexedEndpoint) error {
// or are using a non-https endpoint. We perform filtering rather than outright rejection because it is generally
// expected that a service provider will successfully support a given ACS so long as they have at least one
// compatible binding.
func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor) error {
func FilterSAMLEntityDescriptor(ed *saml.EntityDescriptor, quiet bool) error {
var originalCount int
var filteredCount int
for i := range ed.SPSSODescriptors {
filtered := slices.DeleteFunc(ed.SPSSODescriptors[i].AssertionConsumerServices, func(acs saml.IndexedEndpoint) bool {
if err := ValidateAssertionConsumerService(acs); err != nil {
log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err)
if !quiet {
log.Warnf("AssertionConsumerService binding for entity %q is invalid and will be ignored: %v", ed.EntityID, err)
}
return true
}

Expand Down
7 changes: 4 additions & 3 deletions lib/services/saml_idp_service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -97,16 +97,17 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) {
{
eds: edBuilder().
ACS(saml.HTTPArtifactBinding, "https://example.com/acs").
ACS("urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST-SimpleSign", "https://example.com/POST-SimpleSign").
ACS(saml.HTTPPostBinding, "https://example.com/acs").
Done(),
ok: true,
before: 2,
before: 3,
after: 1,
name: "binding filtering",
},
{
eds: edBuilder().
ACS(saml.HTTPArtifactBinding, "https://example.com/acs").
ACS("urn:oasis:names:tc:SAML:2.0:bindings:PAOS", "https://example.com/ECP").
ACS(saml.HTTPPostBinding, "http://example.com/acs").
Done(),
ok: false,
Expand All @@ -123,7 +124,7 @@ func TestFilterSAMLEntityDescriptor(t *testing.T) {

require.Equal(t, tt.before, getACSCount(ed))

err = FilterSAMLEntityDescriptor(ed)
err = FilterSAMLEntityDescriptor(ed, false /* quiet */)
if !tt.ok {
require.Error(t, err)
return
Expand Down
19 changes: 10 additions & 9 deletions tool/tctl/common/resource_command.go
Original file line number Diff line number Diff line change
Expand Up @@ -999,16 +999,17 @@ func (rc *ResourceCommand) createSAMLIdPServiceProvider(ctx context.Context, cli
return trace.Wrap(err)
}

// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service provider %q: %v", sp.GetEntityID(), err)
}
if sp.GetEntityDescriptor() != "" {
// verify that entity descriptor parses
ed, err := samlsp.ParseMetadata([]byte(sp.GetEntityDescriptor()))
if err != nil {
return trace.BadParameter("invalid entity descriptor for SAML IdP Service Provider %q: %v", sp.GetEntityID(), err)
}

// try filtering the entity descriptor. if it can't be filtered down to a useable looking state, reject
// the creation attempt.
if err := services.FilterSAMLEntityDescriptor(ed); err != nil {
return trace.Wrap(err)
// issue warning about unsupported ACS bindings.
if err := services.FilterSAMLEntityDescriptor(ed, false /* quiet */); err != nil {
log.Warnf("Entity descriptor for SAML IdP service provider %q contains unsupported ACS bindings: %v", sp.GetEntityID(), err)
}
}

serviceProviderName := sp.GetName()
Expand Down