diff --git a/samlsp/new.go b/samlsp/new.go index 82bfd835..451a65aa 100644 --- a/samlsp/new.go +++ b/samlsp/new.go @@ -15,6 +15,7 @@ import ( // Options represents the parameters for creating a new middleware type Options struct { + EntityID string URL url.URL Key *rsa.PrivateKey Certificate *x509.Certificate @@ -126,6 +127,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider { } return saml.ServiceProvider{ + EntityID: opts.EntityID, Key: opts.Key, Certificate: opts.Certificate, Intermediates: opts.Intermediates, diff --git a/service_provider.go b/service_provider.go index 55435ccf..e64ad065 100644 --- a/service_provider.go +++ b/service_provider.go @@ -58,6 +58,9 @@ type SignatureVerifier interface { // See the example directory for an example of a web application using // the service provider interface. type ServiceProvider struct { + // Entity ID is optional - if not specified then MetadataURL will be used + EntityID string + // Key is the RSA private key we use to sign requests. Key *rsa.PrivateKey @@ -156,7 +159,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor { } return &EntityDescriptor{ - EntityID: sp.MetadataURL.String(), + EntityID: firstSet(sp.EntityID, sp.MetadataURL.String()), ValidUntil: validUntil, SPSSODescriptors: []SPSSODescriptor{ @@ -316,7 +319,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque Version: "2.0", Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: sp.MetadataURL.String(), + Value: firstSet(sp.EntityID, sp.MetadataURL.String()), }, NameIDPolicy: &NameIDPolicy{ AllowCreate: &allowCreate, @@ -656,13 +659,14 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque } audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0 + audience := firstSet(sp.EntityID, sp.MetadataURL.String()) for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions { - if audienceRestriction.Audience.Value == sp.MetadataURL.String() { + if audienceRestriction.Audience.Value == audience { audienceRestrictionsValid = true } } if !audienceRestrictionsValid { - return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.MetadataURL.String()) + return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", audience) } return nil } @@ -802,7 +806,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ Destination: idpURL, Issuer: &Issuer{ Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity", - Value: sp.MetadataURL.String(), + Value: firstSet(sp.EntityID, sp.MetadataURL.String()), }, NameID: &NameID{ Format: sp.nameIDFormat(), @@ -935,3 +939,10 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { return nil } + +func firstSet(a, b string) string { + if a == "" { + return b + } + return a +} diff --git a/service_provider_test.go b/service_provider_test.go index ce370edc..005e8453 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -163,6 +163,29 @@ func TestCanProduceMetadataNoSigningKey(t *testing.T) { string(spMetadata)) } +func TestCanProduceMetadataEntityID(t *testing.T) { + test := NewServiceProviderTest() + s := ServiceProvider{ + EntityID: "spn:11111111-2222-3333-4444-555555555555", + MetadataURL: mustParseURL("https://example.com/saml2/metadata"), + AcsURL: mustParseURL("https://example.com/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + assert.NoError(t, err) + + spMetadata, err := xml.MarshalIndent(s.Metadata(), "", " ") + assert.NoError(t, err) + assert.Equal(t, ""+ + "\n"+ + " \n"+ + " \n"+ + " \n"+ + " \n"+ + "", + string(spMetadata)) +} + func TestSPCanProduceRedirectRequest(t *testing.T) { test := NewServiceProviderTest() TimeNow = func() time.Time {