Skip to content

Commit

Permalink
Add EntityID
Browse files Browse the repository at this point in the history
  • Loading branch information
miketonks committed Jan 31, 2020
1 parent eefb3b2 commit aa58653
Show file tree
Hide file tree
Showing 5 changed files with 30 additions and 17 deletions.
1 change: 1 addition & 0 deletions identity_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
-----END CERTIFICATE-----
`)
test.SP = ServiceProvider{
EntityID: "https://sp.example.com/saml2/metadata",
Key: test.SPKey,
Certificate: test.SPCertificate,
MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"),
Expand Down
1 change: 1 addition & 0 deletions samlidp/samlidp_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ OwJlNCASPZRH/JmF8tX0hoHuAQ==
-----END CERTIFICATE-----
`)
test.SP = saml.ServiceProvider{
EntityID: "https://sp.example.com/saml2/metadata",
Key: test.SPKey,
Certificate: test.SPCertificate,
MetadataURL: mustParseURL("https://sp.example.com/saml2/metadata"),
Expand Down
6 changes: 6 additions & 0 deletions samlsp/new.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import (

// Options represents the parameters for creating a new middleware
type Options struct {
EntityID string
URL url.URL
Key *rsa.PrivateKey
Certificate *x509.Certificate
Expand Down Expand Up @@ -126,6 +127,7 @@ func DefaultServiceProvider(opts Options) saml.ServiceProvider {
}

return saml.ServiceProvider{
EntityID: opts.EntityID,
Key: opts.Key,
Certificate: opts.Certificate,
Intermediates: opts.Intermediates,
Expand Down Expand Up @@ -164,6 +166,10 @@ func New(opts Options) (*Middleware, error) {
opts.IDPMetadata = metadata
}

if opts.EntityID == "" && opts.IDPMetadataURL != nil {
opts.EntityID = opts.IDPMetadataURL.String()
}

m := &Middleware{
ServiceProvider: DefaultServiceProvider(opts),
Binding: "",
Expand Down
36 changes: 19 additions & 17 deletions service_provider.go
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,9 @@ type SignatureVerifier interface {
// See the example directory for an example of a web application using
// the service provider interface.
type ServiceProvider struct {
// Entity ID
EntityID string

// Key is the RSA private key we use to sign requests.
Key *rsa.PrivateKey

Expand Down Expand Up @@ -156,7 +159,7 @@ func (sp *ServiceProvider) Metadata() *EntityDescriptor {
}

return &EntityDescriptor{
EntityID: sp.MetadataURL.String(),
EntityID: sp.EntityID,
ValidUntil: validUntil,

SPSSODescriptors: []SPSSODescriptor{
Expand Down Expand Up @@ -316,7 +319,7 @@ func (sp *ServiceProvider) MakeAuthenticationRequest(idpURL string) (*AuthnReque
Version: "2.0",
Issuer: &Issuer{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: sp.MetadataURL.String(),
Value: sp.EntityID,
},
NameIDPolicy: &NameIDPolicy{
AllowCreate: &allowCreate,
Expand Down Expand Up @@ -655,12 +658,12 @@ func (sp *ServiceProvider) validateAssertion(assertion *Assertion, possibleReque

audienceRestrictionsValid := len(assertion.Conditions.AudienceRestrictions) == 0
for _, audienceRestriction := range assertion.Conditions.AudienceRestrictions {
if audienceRestriction.Audience.Value == sp.MetadataURL.String() {
if audienceRestriction.Audience.Value == sp.EntityID {
audienceRestrictionsValid = true
}
}
if !audienceRestrictionsValid {
return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.MetadataURL.String())
return fmt.Errorf("assertion Conditions AudienceRestriction does not contain %q", sp.EntityID)
}
return nil
}
Expand Down Expand Up @@ -800,7 +803,7 @@ func (sp *ServiceProvider) MakeLogoutRequest(idpURL, nameID string) (*LogoutRequ
Destination: idpURL,
Issuer: &Issuer{
Format: "urn:oasis:names:tc:SAML:2.0:nameid-format:entity",
Value: sp.MetadataURL.String(),
Value: sp.EntityID,
},
NameID: &NameID{
Format: sp.nameIDFormat(),
Expand Down Expand Up @@ -843,7 +846,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) erro
if err != nil {
return fmt.Errorf("unable to parse form: %v", err)
}

return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse"))
}

Expand All @@ -855,11 +858,11 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
}

var resp LogoutResponse

if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil {
return fmt.Errorf("cannot unmarshal response: %s", err)
}

if err := sp.validateLogoutResponse(&resp); err != nil {
return err
}
Expand All @@ -868,7 +871,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error
if err := doc.ReadFromBytes(rawResponseBuf); err != nil {
return err
}

responseEl := doc.Root()
if err = sp.validateSigned(responseEl); err != nil {
return err
Expand All @@ -887,20 +890,20 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
}

gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf))

decoder := xml.NewDecoder(gr)

var resp LogoutResponse

err = decoder.Decode(&resp)
if err != nil {
return fmt.Errorf("unable to flate decode: %s", err)
}

if err := sp.validateLogoutResponse(&resp); err != nil {
return err
}

doc := etree.NewDocument()
if _, err := doc.ReadFrom(gr); err != nil {
return err
Expand All @@ -914,7 +917,6 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str
return nil
}


// validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid.
func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error {
if resp.Destination != sp.SloURL.String() {
Expand All @@ -931,6 +933,6 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error {
if resp.Status.StatusCode.Value != StatusSuccess {
return fmt.Errorf("status code was not %s", StatusSuccess)
}

return nil
}
}
3 changes: 3 additions & 0 deletions service_provider_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ func TestSPCanSetAuthenticationNameIDFormat(t *testing.T) {
func TestSPCanProduceMetadata(t *testing.T) {
test := NewServiceProviderTest()
s := ServiceProvider{
EntityID: "https://example.com/saml2/metadata",
Key: test.Key,
Certificate: test.Certificate,
MetadataURL: mustParseURL("https://example.com/saml2/metadata"),
Expand Down Expand Up @@ -144,6 +145,7 @@ func TestSPCanProduceMetadata(t *testing.T) {
func TestCanProduceMetadataNoSigningKey(t *testing.T) {
test := NewServiceProviderTest()
s := ServiceProvider{
EntityID: "https://example.com/saml2/metadata",
MetadataURL: mustParseURL("https://example.com/saml2/metadata"),
AcsURL: mustParseURL("https://example.com/saml2/acs"),
IDPMetadata: &EntityDescriptor{},
Expand Down Expand Up @@ -171,6 +173,7 @@ func TestSPCanProduceRedirectRequest(t *testing.T) {
}
Clock = dsig.NewFakeClockAt(TimeNow())
s := ServiceProvider{
EntityID: "https://15661444.ngrok.io/saml2/metadata",
Key: test.Key,
Certificate: test.Certificate,
MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"),
Expand Down

0 comments on commit aa58653

Please sign in to comment.