diff --git a/service_provider.go b/service_provider.go index c5801c0d..a2a0551a 100644 --- a/service_provider.go +++ b/service_provider.go @@ -378,6 +378,40 @@ func (ivr *InvalidResponseError) Error() string { return fmt.Sprintf("Authentication failed") } +func responseIsSigned(response *etree.Document) (bool, error) { + signatureElement, err := findChild(response.Root(), "http://www.w3.org/2000/09/xmldsig#", "Signature") + if err != nil { + return false, err + } + return signatureElement != nil, nil +} + +// validateDestination validates the Destination attribute. +// If the response is signed, the Destination is required to be present. +func (sp *ServiceProvider) validateDestination(response []byte, responseDom *Response) error { + responseXml := etree.NewDocument() + err := responseXml.ReadFromBytes(response) + if err != nil { + return err + } + + signed, err := responseIsSigned(responseXml) + if err != nil { + return err + } + + + // Compare if the response is signed OR the Destination is provided. + // (Even if the response is not signed, if the Destination is set it must match.) + if signed || responseDom.Destination != "" { + if responseDom.Destination != sp.AcsURL.String() { + return fmt.Errorf("`Destination` does not match AcsURL (expected %q, actual %q)", sp.AcsURL.String(), responseDom.Destination) + } + } + + return nil +} + // ParseResponse extracts the SAML IDP response received in req, validates // it, and returns the verified attributes of the request. // @@ -409,8 +443,9 @@ func (sp *ServiceProvider) ParseResponse(req *http.Request, possibleRequestIDs [ retErr.PrivateErr = fmt.Errorf("cannot unmarshal response: %s", err) return nil, retErr } - if resp.Destination != sp.AcsURL.String() { - retErr.PrivateErr = fmt.Errorf("`Destination` does not match AcsURL (expected %q)", sp.AcsURL.String()) + + if err := sp.validateDestination(rawResponseBuf, &resp); err != nil { + retErr.PrivateErr = err return nil, retErr } diff --git a/service_provider_test.go b/service_provider_test.go index bb0e0e1e..f18ba5cc 100644 --- a/service_provider_test.go +++ b/service_provider_test.go @@ -17,6 +17,7 @@ import ( "crypto/x509" + "github.com/beevik/etree" . "gopkg.in/check.v1" ) @@ -642,6 +643,109 @@ func (test *ServiceProviderTest) TestCanParseResponse(c *C) { }) } +func (test *ServiceProviderTest) replaceDestination(newDestination string) { + newStr := "" + if newDestination != "" { + newStr = `Destination="` + newDestination + `"` + } + test.SamlResponse = strings.Replace(test.SamlResponse, `Destination="https://15661444.ngrok.io/saml2/acs"`, newStr, 1) +} + +func (test *ServiceProviderTest) TestCanProcessResponseWithoutDestination(c *C) { + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + c.Assert(err, IsNil) + + req := http.Request{PostForm: url.Values{}} + test.replaceDestination("") + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString([]byte(test.SamlResponse))) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + c.Assert(err, Equals, nil) +} + +func (test *ServiceProviderTest) responseDom() (doc *etree.Document) { + doc = etree.NewDocument() + doc.ReadFromString(test.SamlResponse) + return doc +} + +func addSignatureToDocument(doc *etree.Document) *etree.Document { + responseEl := doc.FindElement("//Response") + signatureEl := doc.CreateElement("xmldsig:Signature") + signatureEl.CreateAttr("xmlns:xmldsig", "http://www.w3.org/2000/09/xmldsig#") + responseEl.AddChild(signatureEl) + return doc +} + +func removeDestinationFromDocument(doc *etree.Document) *etree.Document { + responseEl := doc.FindElement("//Response") + responseEl.RemoveAttr("Destination") + return doc +} + +func (test *ServiceProviderTest) TestMismatchedDestinationsWithSignaturePresent(c *C) { + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + c.Assert(err, IsNil) + + req := http.Request{PostForm: url.Values{}} + test.replaceDestination("https://wrong/saml2/acs") + bytes, _ := addSignatureToDocument(test.responseDom()).WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + c.Assert(err.(*InvalidResponseError).PrivateErr.Error(), Equals, "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"https://wrong/saml2/acs\")") +} + +func (test *ServiceProviderTest) TestMismatchedDestinationsWithNoSignaturePresent(c *C) { + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + c.Assert(err, IsNil) + + req := http.Request{PostForm: url.Values{}} + test.replaceDestination("https://wrong/saml2/acs") + bytes, _ := test.responseDom().WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + c.Assert(err.(*InvalidResponseError).PrivateErr.Error(), Equals, "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"https://wrong/saml2/acs\")") +} + +func (test *ServiceProviderTest) TestMissingDestinationWithSignaturePresent(c *C) { + s := ServiceProvider{ + Key: test.Key, + Certificate: test.Certificate, + MetadataURL: mustParseURL("https://15661444.ngrok.io/saml2/metadata"), + AcsURL: mustParseURL("https://15661444.ngrok.io/saml2/acs"), + IDPMetadata: &EntityDescriptor{}, + } + err := xml.Unmarshal([]byte(test.IDPMetadata), &s.IDPMetadata) + c.Assert(err, IsNil) + + req := http.Request{PostForm: url.Values{}} + test.replaceDestination("") + bytes, _ := addSignatureToDocument(test.responseDom()).WriteToBytes() + req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString(bytes)) + _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) + c.Assert(err.(*InvalidResponseError).PrivateErr.Error(), Equals, "`Destination` does not match AcsURL (expected \"https://15661444.ngrok.io/saml2/acs\", actual \"\")") +} + func (test *ServiceProviderTest) TestInvalidResponses(c *C) { s := ServiceProvider{ Key: test.Key, @@ -662,12 +766,6 @@ func (test *ServiceProviderTest) TestInvalidResponses(c *C) { _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) c.Assert(err.(*InvalidResponseError).PrivateErr, ErrorMatches, "cannot unmarshal response: expected element type but have ") - s.AcsURL = mustParseURL("https://wrong/saml2/acs") - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString([]byte(test.SamlResponse))) - _, err = s.ParseResponse(&req, []string{"id-9e61753d64e928af5a7a341a97f420c9"}) - c.Assert(err.(*InvalidResponseError).PrivateErr.Error(), Equals, "`Destination` does not match AcsURL (expected \"https://wrong/saml2/acs\")") - s.AcsURL = mustParseURL("https://15661444.ngrok.io/saml2/acs") - req.PostForm.Set("SAMLResponse", base64.StdEncoding.EncodeToString([]byte(test.SamlResponse))) _, err = s.ParseResponse(&req, []string{"wrongRequestID"}) c.Assert(err.(*InvalidResponseError).PrivateErr.Error(), Equals, "`InResponseTo` does not match any of the possible request IDs (expected [wrongRequestID])")