From b55e20e6723a43c27dae08f85c31e49f7cd983a1 Mon Sep 17 00:00:00 2001 From: sshahcodes Date: Thu, 29 Feb 2024 18:31:19 -0500 Subject: [PATCH 1/2] fix: : return non-nil error regardless of http status --- lib/services/local/saml_idp_service_provider.go | 15 ++++++++------- .../local/saml_idp_service_provider_test.go | 7 +++++++ 2 files changed, 15 insertions(+), 7 deletions(-) diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index 21a0ea4e8ec9b..ffc86cdde5ff5 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -112,13 +112,14 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co // CreateSAMLIdPServiceProvider creates a new SAML IdP service provider resource. func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { if sp.GetEntityDescriptor() == "" { - if err := s.fetchAndSetEntityDescriptor(sp); err != nil { + // fetchAndSetEntityDescriptor is expected to return error if it fails + // to set entity descriptor. Let's still be defensive and double check + // for sp.GetEntityDescriptor() value. + if err := s.fetchAndSetEntityDescriptor(sp); err != nil || sp.GetEntityDescriptor() == "" { // We aren't interested in checking error type as any occurrence of error mean entity descriptor was not set. - // But a debug log should be helpful to indicate source of error. - s.log.Debugf("Failed to fetch entity descriptor from %s. %s.", sp.GetEntityID(), err.Error()) - if err := s.generateAndSetEntityDescriptor(sp); err != nil { - return trace.BadParameter("could not generate entity descriptor with given entity_id and acs_url.") + return trace.BadParameter("could not generate entity descriptor with given entity_id %q and acs_url %q: %v", + sp.GetEntityID(), sp.GetACSURL(), err) } } } @@ -256,7 +257,7 @@ func (s *SAMLIdPServiceProviderService) fetchAndSetEntityDescriptor(sp types.SAM defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return trace.Wrap(trace.ReadError(resp.StatusCode, nil)) + return trace.NotFound("entity descriptor not found on the given endpoint") } body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize) @@ -277,7 +278,7 @@ func (s *SAMLIdPServiceProviderService) fetchAndSetEntityDescriptor(sp types.SAM // generateAndSetEntityDescriptor generates and sets Service Provider entity descriptor // with ACS URL, Entity ID and unspecified NameID format. func (s *SAMLIdPServiceProviderService) generateAndSetEntityDescriptor(sp types.SAMLIdPServiceProvider) error { - s.log.Infof("Generating a default entity_descriptor with entity_id %s and acs_url %s.", sp.GetEntityID(), sp.GetACSURL()) + s.log.Infof("Generating a default entity_descriptor with entity_id %q and acs_url %q.", sp.GetEntityID(), sp.GetACSURL()) acsURL, err := url.Parse(sp.GetACSURL()) if err != nil { diff --git a/lib/services/local/saml_idp_service_provider_test.go b/lib/services/local/saml_idp_service_provider_test.go index dbf0e88af378f..a5c0ce7d384df 100644 --- a/lib/services/local/saml_idp_service_provider_test.go +++ b/lib/services/local/saml_idp_service_provider_test.go @@ -323,6 +323,8 @@ func TestCreateSAMLIdPServiceProvider_fetchAndSetEntityDescriptor(t *testing.T) switch r.RequestURI { case "/status-not-ok": w.WriteHeader(http.StatusNotFound) + case "/status-302-found": + w.WriteHeader(http.StatusFound) case "/invalid-metadata": fmt.Fprintln(w, "test") default: @@ -344,6 +346,11 @@ func TestCreateSAMLIdPServiceProvider_fetchAndSetEntityDescriptor(t *testing.T) entityID: fmt.Sprintf("%s/status-not-ok", testSPServer.URL), wantErr: true, }, + { + name: "status 302 found", + entityID: fmt.Sprintf("%s/status-302-found", testSPServer.URL), + wantErr: true, + }, { name: "invalid metadata", entityID: fmt.Sprintf("%s/invalid-metadata", testSPServer.URL), From ee1b015ce0448c3684b544398e20d1aef21e0aa4 Mon Sep 17 00:00:00 2001 From: sshahcodes Date: Tue, 5 Mar 2024 13:13:58 -0500 Subject: [PATCH 2/2] use trace.Badparameter, bring debug log back, remove defensive sp.GetEntityDescriptor() == check --- lib/services/local/saml_idp_service_provider.go | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/lib/services/local/saml_idp_service_provider.go b/lib/services/local/saml_idp_service_provider.go index ffc86cdde5ff5..66eb1eea7693c 100644 --- a/lib/services/local/saml_idp_service_provider.go +++ b/lib/services/local/saml_idp_service_provider.go @@ -113,9 +113,9 @@ func (s *SAMLIdPServiceProviderService) GetSAMLIdPServiceProvider(ctx context.Co func (s *SAMLIdPServiceProviderService) CreateSAMLIdPServiceProvider(ctx context.Context, sp types.SAMLIdPServiceProvider) error { if sp.GetEntityDescriptor() == "" { // fetchAndSetEntityDescriptor is expected to return error if it fails - // to set entity descriptor. Let's still be defensive and double check - // for sp.GetEntityDescriptor() value. - if err := s.fetchAndSetEntityDescriptor(sp); err != nil || sp.GetEntityDescriptor() == "" { + // to fetch a valid entity descriptor. + if err := s.fetchAndSetEntityDescriptor(sp); err != nil { + s.log.Debugf("Failed to fetch entity descriptor from %q. %v.", sp.GetEntityID(), err) // We aren't interested in checking error type as any occurrence of error mean entity descriptor was not set. if err := s.generateAndSetEntityDescriptor(sp); err != nil { return trace.BadParameter("could not generate entity descriptor with given entity_id %q and acs_url %q: %v", @@ -257,7 +257,7 @@ func (s *SAMLIdPServiceProviderService) fetchAndSetEntityDescriptor(sp types.SAM defer resp.Body.Close() if resp.StatusCode != http.StatusOK { - return trace.NotFound("entity descriptor not found on the given endpoint") + return trace.Wrap(trace.BadParameter("unexpected response status: %q", resp.StatusCode)) } body, err := utils.ReadAtMost(resp.Body, teleport.MaxHTTPResponseSize)