diff --git a/lib/services/saml.go b/lib/services/saml.go index 354592ff51245..4f58ac4cd0901 100644 --- a/lib/services/saml.go +++ b/lib/services/saml.go @@ -39,6 +39,28 @@ import ( "github.com/gravitational/teleport/lib/utils" ) +func setEntityDescriptorFromURL(sc types.SAMLConnector) error { + if sc.GetEntityDescriptorURL() == "" { + return nil + } + + resp, err := http.Get(sc.GetEntityDescriptorURL()) + if err != nil { + return trace.WrapWithMessage(err, "unable to fetch entity descriptor from %v for SAML connector %v", sc.GetEntityDescriptorURL(), sc.GetName()) + } + if resp.StatusCode != http.StatusOK { + return trace.BadParameter("status code %v when fetching from %v for SAML connector %v", resp.StatusCode, sc.GetEntityDescriptorURL(), sc.GetName()) + } + defer resp.Body.Close() + body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) + if err != nil { + return trace.Wrap(err) + } + sc.SetEntityDescriptor(string(body)) + log.Debugf("[SAML] Successfully fetched entity descriptor from %v for connector %v", sc.GetEntityDescriptorURL(), sc.GetName()) + return nil +} + // ValidateSAMLConnector validates the SAMLConnector and sets default values. // If a remote to fetch roles is specified, roles will be validated to exist. func ValidateSAMLConnector(sc types.SAMLConnector, rg RoleGetter) error { @@ -46,21 +68,8 @@ func ValidateSAMLConnector(sc types.SAMLConnector, rg RoleGetter) error { return trace.Wrap(err) } - if sc.GetEntityDescriptorURL() != "" { - resp, err := http.Get(sc.GetEntityDescriptorURL()) - if err != nil { - return trace.WrapWithMessage(err, "unable to fetch entity descriptor from %v for SAML connector %v", sc.GetEntityDescriptorURL(), sc.GetName()) - } - if resp.StatusCode != http.StatusOK { - return trace.BadParameter("status code %v when fetching from %v for SAML connector %v", resp.StatusCode, sc.GetEntityDescriptorURL(), sc.GetName()) - } - defer resp.Body.Close() - body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) - if err != nil { - return trace.Wrap(err) - } - sc.SetEntityDescriptor(string(body)) - log.Debugf("[SAML] Successfully fetched entity descriptor from %v for connector %v", sc.GetEntityDescriptorURL(), sc.GetName()) + if err := setEntityDescriptorFromURL(sc); err != nil { + log.Errorf("error loading entity descriptor from URL: %s", err) } if sc.GetEntityDescriptor() != "" {