Skip to content
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 24 additions & 15 deletions lib/services/saml.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,28 +39,37 @@ import (
"github.com/gravitational/teleport/lib/utils"
)

func setEntityDescriptorFromURL(sc types.SAMLConnector) error {
if sc.GetEntityDescriptorURL() == "" {
return nil
}

resp, err := http.Get(sc.GetEntityDescriptorURL())
if err != nil {
return trace.WrapWithMessage(err, "unable to fetch entity descriptor from %v for SAML connector %v", sc.GetEntityDescriptorURL(), sc.GetName())
}
if resp.StatusCode != http.StatusOK {
return trace.BadParameter("status code %v when fetching from %v for SAML connector %v", resp.StatusCode, sc.GetEntityDescriptorURL(), sc.GetName())
}
defer resp.Body.Close()
body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
if err != nil {
return trace.Wrap(err)
}
sc.SetEntityDescriptor(string(body))
log.Debugf("[SAML] Successfully fetched entity descriptor from %v for connector %v", sc.GetEntityDescriptorURL(), sc.GetName())
return nil
}

// ValidateSAMLConnector validates the SAMLConnector and sets default values.
// If a remote to fetch roles is specified, roles will be validated to exist.
func ValidateSAMLConnector(sc types.SAMLConnector, rg RoleGetter) error {
if err := sc.CheckAndSetDefaults(); err != nil {
return trace.Wrap(err)
}

if sc.GetEntityDescriptorURL() != "" {
resp, err := http.Get(sc.GetEntityDescriptorURL())
if err != nil {
return trace.WrapWithMessage(err, "unable to fetch entity descriptor from %v for SAML connector %v", sc.GetEntityDescriptorURL(), sc.GetName())
}
if resp.StatusCode != http.StatusOK {
return trace.BadParameter("status code %v when fetching from %v for SAML connector %v", resp.StatusCode, sc.GetEntityDescriptorURL(), sc.GetName())
}
defer resp.Body.Close()
body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)
if err != nil {
return trace.Wrap(err)
}
sc.SetEntityDescriptor(string(body))
log.Debugf("[SAML] Successfully fetched entity descriptor from %v for connector %v", sc.GetEntityDescriptorURL(), sc.GetName())
if err := setEntityDescriptorFromURL(sc); err != nil {
log.Errorf("error loading entity descriptor from URL: %s", err)
}

if sc.GetEntityDescriptor() != "" {
Expand Down