From a2ca41268c392ea69c89ca98293b5bfac64ab3df Mon Sep 17 00:00:00 2001 From: Mike Tonks Date: Mon, 3 Feb 2020 11:42:14 +0000 Subject: [PATCH] Add EntityID --- samlsp/new.go | 2 ++ service_provider.go | 44 ++++++++++++++++++++++++---------------- service_provider_test.go | 23 +++++++++++++++++++++ 3 files changed, 52 insertions(+), 17 deletions(-) diff --git a/samlsp/new.go b/samlsp/new.go index 82bfd835..451a65aa 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -15,6 +15,7 @@ import ( // Options represents the parameters for creating a new middleware type Options struct { + EntityID string URL url.URL Key *rsa.PrivateKey Certificate *x509.Certificate @@ -126,6 +127,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { } return saml.ServiceProvider{ + EntityID: opts.EntityID, Key: opts.Key, Certificate: opts.Certificate, Intermediates: opts.Intermediates, diff --git a/service_provider.go b/service_provider.go index 7780d881..d2ca89b3 100644 --- a/service_provider.go +++ b/service_provider.go @@ -58,6 +58,9 @@ type SignatureVerifier interface { // See the example directory for an example of a web application using // the service provider interface. type ServiceProvider struct { + // Entity ID is optional - if not specified then MetadataURL will be used + EntityID string + // Key is the RSA private key we use to sign requests. Key *rsa.PrivateKey @@ -156,7 +159,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { } return &EntityDescriptor{ - EntityID: sp.MetadataURL.String(), + EntityID: firstSet(sp.EntityID, sp.MetadataURL.String()), ValidUntil: validUntil, SPSSODescriptors: []SPSSODescriptor{ @@ -316,7 +319,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque Version: "2.0", Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: sp.MetadataURL.String(), + Value: firstSet(sp.EntityID, sp.MetadataURL.String()), }, NameIDPolicy: &NameIDPolicy{ AllowCreate: &allowCreate, @@ -654,13 +657,14 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque } audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0 + audience := firstSet(sp.EntityID, sp.MetadataURL.String()) for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions { - if audienceRestriction.Audience.Value == sp.MetadataURL.String() { + if audienceRestriction.Audience.Value == audience { audienceRestrictionsValid = true } } if !audienceRestrictionsValid { - return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.MetadataURL.String()) + return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", audience) } return nil } @@ -800,7 +804,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ Destination: idpURL, Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: sp.MetadataURL.String(), + Value: firstSet(sp.EntityID, sp.MetadataURL.String()), }, NameID: &NameID{ Format: sp.nameIDFormat(), @@ -843,7 +847,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) erro if err != nil { return fmt.Errorf("unable to parse form: %v", err) } - + return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse")) } @@ -855,11 +859,11 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error } var resp LogoutResponse - + if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil { return fmt.Errorf("cannot unmarshal response: %s", err) } - + if err := sp.validateLogoutResponse(&resp); err != nil { return err } @@ -868,7 +872,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error if err := doc.ReadFromBytes(rawResponseBuf); err != nil { return err } - + responseEl := doc.Root() if err = sp.validateSigned(responseEl); err != nil { return err @@ -887,20 +891,20 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str } gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf)) - + decoder := xml.NewDecoder(gr) - + var resp LogoutResponse - + err = decoder.Decode(&resp) if err != nil { return fmt.Errorf("unable to flate decode: %s", err) } - + if err := sp.validateLogoutResponse(&resp); err != nil { return err } - + doc := etree.NewDocument() if _, err := doc.ReadFrom(gr); err != nil { return err @@ -914,7 +918,6 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str return nil } - // validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid. func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { if resp.Destination != sp.SloURL.String() { @@ -931,6 +934,13 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { if resp.Status.StatusCode.Value != StatusSuccess { return fmt.Errorf("status code was not %s", StatusSuccess) } - + return nil -} \ No newline at end of file +} + +func firstSet(a, b string) string { + if a == "" { + return b + } + return a +} diff --git a/service_provider_test.go b/service_provider_test.go index ce370edc..005e8453 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -163,6 +163,29 @@ func TestCanProduceMetadataNoSigningKey(t *testing.T) { string(spMetadata)) } +func TestCanProduceMetadataEntityID(t *testing.T) { + test := NewServiceProviderTest() + s := ServiceProvider{ + EntityID: "spn:11111111-2222-3333-4444-555555555555", + MetadataURL: mustParseURL("https://example.com/saml2/metadata"), + AcsURL: mustParseURL("https://example.com/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + assert.NoError(t, err) + + spMetadata, err := xml.MarshalIndent(s.Metadata(), "", " ") + assert.NoError(t, err) + assert.Equal(t, ""+ + "\n"+ + " \n"+ + " \n"+ + " \n"+ + " \n"+ + "", + string(spMetadata)) +} + func TestSPCanProduceRedirectRequest(t *testing.T) { test := NewServiceProviderTest() TimeNow = func() time.Time {