From aa586539bab7ca26043cb08f0f6ba5674cd8de3e Mon Sep 17 00:00:00 2001 From: Mike Tonks Date: Fri, 31 Jan 2020 15:17:21 +0000 Subject: [PATCH] Add EntityID --- identity_provider_test.go | 1 + samlidp/samlidp_test.go | 1 + samlsp/new.go | 6 ++++++ service_provider.go | 36 +++++++++++++++++++----------------- service_provider_test.go | 3 +++ 5 files changed, 30 insertions(+), 17 deletions(-) diff --git a/identity_provider_test.go b/identity_provider_test.go index cb92415d..8ceef682 100644 --- a/identity_provider_test.go +++ b/identity_provider_test.go @@ -113,6 +113,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ== -----END CERTIFICATE----- `) test.SP = ServiceProvider{ + EntityID: "https://sp.example.com/saml2/metadata", Key: test.SPKey, Certificate: test.SPCertificate, MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"), diff --git a/samlidp/samlidp_test.go b/samlidp/samlidp_test.go index 89e833e2..87269fe5 100644 --- a/samlidp/samlidp_test.go +++ b/samlidp/samlidp_test.go @@ -114,6 +114,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ== -----END CERTIFICATE----- `) test.SP = saml.ServiceProvider{ + EntityID: "https://sp.example.com/saml2/metadata", Key: test.SPKey, Certificate: test.SPCertificate, MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"), diff --git a/samlsp/new.go b/samlsp/new.go index 82bfd835..5bbf7f3e 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -15,6 +15,7 @@ import ( // Options represents the parameters for creating a new middleware type Options struct { + EntityID string URL url.URL Key *rsa.PrivateKey Certificate *x509.Certificate @@ -126,6 +127,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { } return saml.ServiceProvider{ + EntityID: opts.EntityID, Key: opts.Key, Certificate: opts.Certificate, Intermediates: opts.Intermediates, @@ -164,6 +166,10 @@ func New(opts Options) (*Middleware, error) { opts.IDPMetadata = metadata } + if opts.EntityID == "" && opts.IDPMetadataURL != nil { + opts.EntityID = opts.IDPMetadataURL.String() + } + m := &Middleware{ ServiceProvider: DefaultServiceProvider(opts), Binding: "", diff --git a/service_provider.go b/service_provider.go index 7780d881..dfee1112 100644 --- a/service_provider.go +++ b/service_provider.go @@ -58,6 +58,9 @@ type SignatureVerifier interface { // See the example directory for an example of a web application using // the service provider interface. type ServiceProvider struct { + // Entity ID + EntityID string + // Key is the RSA private key we use to sign requests. Key *rsa.PrivateKey @@ -156,7 +159,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { } return &EntityDescriptor{ - EntityID: sp.MetadataURL.String(), + EntityID: sp.EntityID, ValidUntil: validUntil, SPSSODescriptors: []SPSSODescriptor{ @@ -316,7 +319,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque Version: "2.0", Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: sp.MetadataURL.String(), + Value: sp.EntityID, }, NameIDPolicy: &NameIDPolicy{ AllowCreate: &allowCreate, @@ -655,12 +658,12 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0 for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions { - if audienceRestriction.Audience.Value == sp.MetadataURL.String() { + if audienceRestriction.Audience.Value == sp.EntityID { audienceRestrictionsValid = true } } if !audienceRestrictionsValid { - return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.MetadataURL.String()) + return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.EntityID) } return nil } @@ -800,7 +803,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ Destination: idpURL, Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: sp.MetadataURL.String(), + Value: sp.EntityID, }, NameID: &NameID{ Format: sp.nameIDFormat(), @@ -843,7 +846,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) erro if err != nil { return fmt.Errorf("unable to parse form: %v", err) } - + return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse")) } @@ -855,11 +858,11 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error } var resp LogoutResponse - + if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil { return fmt.Errorf("cannot unmarshal response: %s", err) } - + if err := sp.validateLogoutResponse(&resp); err != nil { return err } @@ -868,7 +871,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error if err := doc.ReadFromBytes(rawResponseBuf); err != nil { return err } - + responseEl := doc.Root() if err = sp.validateSigned(responseEl); err != nil { return err @@ -887,20 +890,20 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str } gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf)) - + decoder := xml.NewDecoder(gr) - + var resp LogoutResponse - + err = decoder.Decode(&resp) if err != nil { return fmt.Errorf("unable to flate decode: %s", err) } - + if err := sp.validateLogoutResponse(&resp); err != nil { return err } - + doc := etree.NewDocument() if _, err := doc.ReadFrom(gr); err != nil { return err @@ -914,7 +917,6 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str return nil } - // validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid. func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { if resp.Destination != sp.SloURL.String() { @@ -931,6 +933,6 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { if resp.Status.StatusCode.Value != StatusSuccess { return fmt.Errorf("status code was not %s", StatusSuccess) } - + return nil -} \ No newline at end of file +} diff --git a/service_provider_test.go b/service_provider_test.go index ce370edc..ff042381 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -101,6 +101,7 @@ func TestSPCanSetAuthenticationNameIDFormat(t *testing.T) { func TestSPCanProduceMetadata(t *testing.T) { test := NewServiceProviderTest() s := ServiceProvider{ + EntityID: "https://example.com/saml2/metadata", Key: test.Key, Certificate: test.Certificate, MetadataURL: mustParseURL("https://example.com/saml2/metadata"), @@ -144,6 +145,7 @@ func TestSPCanProduceMetadata(t *testing.T) { func TestCanProduceMetadataNoSigningKey(t *testing.T) { test := NewServiceProviderTest() s := ServiceProvider{ + EntityID: "https://example.com/saml2/metadata", MetadataURL: mustParseURL("https://example.com/saml2/metadata"), AcsURL: mustParseURL("https://example.com/saml2/acs"), IDPMetadata: &EntityDescriptor{}, @@ -171,6 +173,7 @@ func TestSPCanProduceRedirectRequest(t *testing.T) { } Clock = dsig.NewFakeClockAt(TimeNow()) s := ServiceProvider{ + EntityID: "https://15661444.ngrok.io/saml2/metadata", Key: test.Key, Certificate: test.Certificate, MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"),