Skip to content

Commit

Permalink
Add EntityID
Browse files Browse the repository at this point in the history
  • Loading branch information
miketonks committed Feb 3, 2020
1 parent eefb3b2 commit a2ca412
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 17 deletions.
2 changes: 2 additions & 0 deletions samlsp/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

// Options represents the parameters for creating a new middleware
type Options struct {
EntityID string
URL url.URL
Key *rsa.PrivateKey
Certificate *x509.Certificate
Expand Down Expand Up @@ -126,6 +127,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
}

return saml.ServiceProvider{
EntityID: opts.EntityID,
Key: opts.Key,
Certificate: opts.Certificate,
Intermediates: opts.Intermediates,
Expand Down
44 changes: 27 additions & 17 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type SignatureVerifier interface {
// See the example directory for an example of a web application using
// the service provider interface.
type ServiceProvider struct {
// Entity ID is optional - if not specified then MetadataURL will be used
EntityID string

// Key is the RSA private key we use to sign requests.
Key *rsa.PrivateKey

Expand Down Expand Up @@ -156,7 +159,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor {
}

return &EntityDescriptor{
EntityID: sp.MetadataURL.String(),
EntityID: firstSet(sp.EntityID, sp.MetadataURL.String()),
ValidUntil: validUntil,

SPSSODescriptors: []SPSSODescriptor{
Expand Down Expand Up @@ -316,7 +319,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque
Version: "2.0",
Issuer: &Issuer{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: sp.MetadataURL.String(),
Value: firstSet(sp.EntityID, sp.MetadataURL.String()),
},
NameIDPolicy: &NameIDPolicy{
AllowCreate: &allowCreate,
Expand Down Expand Up @@ -654,13 +657,14 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque
}

audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0
audience := firstSet(sp.EntityID, sp.MetadataURL.String())
for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions {
if audienceRestriction.Audience.Value == sp.MetadataURL.String() {
if audienceRestriction.Audience.Value == audience {
audienceRestrictionsValid = true
}
}
if !audienceRestrictionsValid {
return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.MetadataURL.String())
return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", audience)
}
return nil
}
Expand Down Expand Up @@ -800,7 +804,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ
Destination: idpURL,
Issuer: &Issuer{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: sp.MetadataURL.String(),
Value: firstSet(sp.EntityID, sp.MetadataURL.String()),
},
NameID: &NameID{
Format: sp.nameIDFormat(),
Expand Down Expand Up @@ -843,7 +847,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) erro
if err != nil {
return fmt.Errorf("unable to parse form: %v", err)
}

return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse"))
}

Expand All @@ -855,11 +859,11 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
}

var resp LogoutResponse

if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil {
return fmt.Errorf("cannot unmarshal response: %s", err)
}

if err := sp.validateLogoutResponse(&resp); err != nil {
return err
}
Expand All @@ -868,7 +872,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
if err := doc.ReadFromBytes(rawResponseBuf); err != nil {
return err
}

responseEl := doc.Root()
if err = sp.validateSigned(responseEl); err != nil {
return err
Expand All @@ -887,20 +891,20 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
}

gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf))

decoder := xml.NewDecoder(gr)

var resp LogoutResponse

err = decoder.Decode(&resp)
if err != nil {
return fmt.Errorf("unable to flate decode: %s", err)
}

if err := sp.validateLogoutResponse(&resp); err != nil {
return err
}

doc := etree.NewDocument()
if _, err := doc.ReadFrom(gr); err != nil {
return err
Expand All @@ -914,7 +918,6 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
return nil
}


// validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid.
func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error {
if resp.Destination != sp.SloURL.String() {
Expand All @@ -931,6 +934,13 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error {
if resp.Status.StatusCode.Value != StatusSuccess {
return fmt.Errorf("status code was not %s", StatusSuccess)
}

return nil
}
}

func firstSet(a, b string) string {
if a == "" {
return b
}
return a
}
23 changes: 23 additions & 0 deletions service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -163,6 +163,29 @@ func TestCanProduceMetadataNoSigningKey(t *testing.T) {
string(spMetadata))
}

func TestCanProduceMetadataEntityID(t *testing.T) {
test := NewServiceProviderTest()
s := ServiceProvider{
EntityID: "spn:11111111-2222-3333-4444-555555555555",
MetadataURL: mustParseURL("https://example.com/saml2/metadata"),
AcsURL: mustParseURL("https://example.com/saml2/acs"),
IDPMetadata: &EntityDescriptor{},
}
err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata)
assert.NoError(t, err)

spMetadata, err := xml.MarshalIndent(s.Metadata(), "", " ")
assert.NoError(t, err)
assert.Equal(t, ""+
"<EntityDescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2015-12-03T01:57:09Z\" entityID=\"spn:11111111-2222-3333-4444-555555555555\">\n"+
" <SPSSODescriptor xmlns=\"urn:oasis:names:tc:SAML:2.0:metadata\" validUntil=\"2015-12-03T01:57:09Z\" protocolSupportEnumeration=\"urn:oasis:names:tc:SAML:2.0:protocol\" AuthnRequestsSigned=\"false\" WantAssertionsSigned=\"true\">\n"+
" <SingleLogoutService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"\"></SingleLogoutService>\n"+
" <AssertionConsumerService Binding=\"urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST\" Location=\"https://example.com/saml2/acs\" index=\"1\"></AssertionConsumerService>\n"+
" </SPSSODescriptor>\n"+
"</EntityDescriptor>",
string(spMetadata))
}

func TestSPCanProduceRedirectRequest(t *testing.T) {
test := NewServiceProviderTest()
TimeNow = func() time.Time {
Expand Down

0 comments on commit a2ca412

Please sign in to comment.