From 625396c43dd2f27cac4c8866027e30761f9945d0 Mon Sep 17 00:00:00 2001 From: Mathieu Mailhos Date: Mon, 3 Feb 2020 06:13:31 +1100 Subject: [PATCH] fix(sp): no check for InResponseTo for if IDPInitiated is true (#259) --- service_provider.go | 39 ++++++++++++++++++++------------------- 1 file changed, 20 insertions(+), 19 deletions(-) diff --git a/service_provider.go b/service_provider.go index 7780d881..55435ccf 100644 --- a/service_provider.go +++ b/service_provider.go @@ -514,16 +514,18 @@ func (sp *ServiceProvider) ParseXMLResponse(decodedResponseXML []byte, possibleR return nil, retErr } - if sp.AllowIDPInitiated && len(possibleRequestIDs) == 0 { - possibleRequestIDs = append([]string{""}) - } - requestIDvalid := false - for _, possibleRequestID := range possibleRequestIDs { - if resp.InResponseTo == possibleRequestID { - requestIDvalid = true + + if sp.AllowIDPInitiated { + requestIDvalid = true + } else { + for _, possibleRequestID := range possibleRequestIDs { + if resp.InResponseTo == possibleRequestID { + requestIDvalid = true + } } } + if !requestIDvalid { retErr.PrivateErr = fmt.Errorf("`InResponseTo` does not match any of the possible request IDs (expected %v)", possibleRequestIDs) return nil, retErr @@ -843,7 +845,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseRequest(req *http.Request) erro if err != nil { return fmt.Errorf("unable to parse form: %v", err) } - + return sp.ValidateLogoutResponseForm(req.PostForm.Get("SAMLResponse")) } @@ -855,11 +857,11 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error } var resp LogoutResponse - + if err := xml.Unmarshal(rawResponseBuf, &resp); err != nil { return fmt.Errorf("cannot unmarshal response: %s", err) } - + if err := sp.validateLogoutResponse(&resp); err != nil { return err } @@ -868,7 +870,7 @@ func (sp *ServiceProvider) ValidateLogoutResponseForm(postFormData string) error if err := doc.ReadFromBytes(rawResponseBuf); err != nil { return err } - + responseEl := doc.Root() if err = sp.validateSigned(responseEl); err != nil { return err @@ -887,20 +889,20 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str } gr := flate.NewReader(bytes.NewBuffer(rawResponseBuf)) - + decoder := xml.NewDecoder(gr) - + var resp LogoutResponse - + err = decoder.Decode(&resp) if err != nil { return fmt.Errorf("unable to flate decode: %s", err) } - + if err := sp.validateLogoutResponse(&resp); err != nil { return err } - + doc := etree.NewDocument() if _, err := doc.ReadFrom(gr); err != nil { return err @@ -914,7 +916,6 @@ func (sp *ServiceProvider) ValidateLogoutResponseRedirect(queryParameterData str return nil } - // validateLogoutResponse validates the LogoutResponse fields. Returns a nil error if the LogoutResponse is valid. func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { if resp.Destination != sp.SloURL.String() { @@ -931,6 +932,6 @@ func (sp *ServiceProvider) validateLogoutResponse(resp *LogoutResponse) error { if resp.Status.StatusCode.Value != StatusSuccess { return fmt.Errorf("status code was not %s", StatusSuccess) } - + return nil -} \ No newline at end of file +}