From 556df518fa0d8e433f1625f4d53d7594da2e2cef Mon Sep 17 00:00:00 2001 From: Mathieu Mailhos Date: Tue, 14 Jan 2020 11:30:56 +1100 Subject: [PATCH] feat(slo): add logout response request validation --- service_provider.go | 93 +++++++++++++++++++++++++++++++++++++-------- 1 file changed, 78 insertions(+), 15 deletions(-) diff --git a/service_provider.go b/service_provider.go index 9fe1d3e0..7780d881 100644 --- a/service_provider.go +++ b/service_provider.go @@ -833,37 +833,79 @@ func (sp *ServiceProvider) nameIDFormat() string { return nameIDFormat } -// ValidateLogoutResponse returns a nil error iff the logout request is valid. -func (sp *ServiceProvider) ValidateLogoutResponse(r *http.Request) error { - r.ParseForm() - rawResponseBuf, err := base64.StdEncoding.DecodeString(r.PostForm.Get("SAMLResponse")) +// ValidateLogoutResponseRequest validates the LogoutResponse content from the request +func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) error { + if data := req.URL.Query().Get("SAMLResponse"); data != "" { + return sp.ValidateLogoutResponseRedirect(data) + } + + err := req.ParseForm() + if err != nil { + return fmt.Errorf("unable to parse form: %v", err) + } + + return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse")) +} + +// ValidatePostLogoutResponse returns a nil error if the logout response is valid. +func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error { + rawResponseBuf, err := base64.StdEncoding.DecodeString(postFormData) if err != nil { return fmt.Errorf("unable to parse base64: %s", err) } - resp := LogoutResponse{} + var resp LogoutResponse + if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil { return fmt.Errorf("cannot unmarshal response: %s", err) } - if resp.Destination != sp.SloURL.String() { - return fmt.Errorf("`Destination` does not match SloURL (expected %q)", sp.SloURL.String()) + + if err := sp.validateLogoutResponse(&resp); err != nil { + return err } - now := time.Now() - if resp.IssueInstant.Add(MaxIssueDelay).Before(now) { - return fmt.Errorf("issueInstant expired at %s", resp.IssueInstant.Add(MaxIssueDelay)) + doc := etree.NewDocument() + if err := doc.ReadFromBytes(rawResponseBuf); err != nil { + return err } - if resp.Issuer.Value != sp.IDPMetadata.EntityID { - return fmt.Errorf("issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID) + + responseEl := doc.Root() + if err = sp.validateSigned(responseEl); err != nil { + return err } - if resp.Status.StatusCode.Value != StatusSuccess { - return fmt.Errorf("status code was not %s", StatusSuccess) + + return nil +} + +// ValidateRedirectLogoutResponse returns a nil error if the logout response is valid. +// URL Binding appears to be gzip / flate encoded +// See https://www.oasis-open.org/committees/download.php/20645/sstc-saml-tech-overview-2%200-draft-10.pdf 6.6 +func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData string) error { + rawResponseBuf, err := base64.StdEncoding.DecodeString(queryParameterData) + if err != nil { + return fmt.Errorf("unable to parse base64: %s", err) } + gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf)) + + decoder := xml.NewDecoder(gr) + + var resp LogoutResponse + + err = decoder.Decode(&resp) + if err != nil { + return fmt.Errorf("unable to flate decode: %s", err) + } + + if err := sp.validateLogoutResponse(&resp); err != nil { + return err + } + doc := etree.NewDocument() - if err := doc.ReadFromBytes(rawResponseBuf); err != nil { + if _, err := doc.ReadFrom(gr); err != nil { return err } + responseEl := doc.Root() if err = sp.validateSigned(responseEl); err != nil { return err @@ -871,3 +913,24 @@ func (sp *ServiceProvider) ValidateLogoutResponse(r *http.Request) error { return nil } + + +// validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid. +func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { + if resp.Destination != sp.SloURL.String() { + return fmt.Errorf("`Destination` does not match SloURL (expected %q)", sp.SloURL.String()) + } + + now := time.Now() + if resp.IssueInstant.Add(MaxIssueDelay).Before(now) { + return fmt.Errorf("issueInstant expired at %s", resp.IssueInstant.Add(MaxIssueDelay)) + } + if resp.Issuer.Value != sp.IDPMetadata.EntityID { + return fmt.Errorf("issuer does not match the IDP metadata (expected %q)", sp.IDPMetadata.EntityID) + } + if resp.Status.StatusCode.Value != StatusSuccess { + return fmt.Errorf("status code was not %s", StatusSuccess) + } + + return nil +} \ No newline at end of file