Skip to content

Commit

Permalink
samlsp: move the setting and reading of cookies into an interface
Browse files Browse the repository at this point in the history
We’ve had a bunch of changes requesting the ability to customize
how cookies are set and it is getting a little messy. This change
moves the code to setting and reading cookies into two interfaces
which you can extend/customize.
  • Loading branch information
crewjam committed Jan 8, 2018
1 parent c9c2cbc commit 4a651a6
Show file tree
Hide file tree
Showing 4 changed files with 146 additions and 59 deletions.
105 changes: 105 additions & 0 deletions samlsp/cookie.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
package samlsp

import (
"net/http"
"strings"
"time"

"github.com/crewjam/saml"
)

// ClientState implements client side storage for state.
type ClientState interface {
SetState(w http.ResponseWriter, r *http.Request, id string, value string)
GetStates(r *http.Request) map[string]string
GetState(r *http.Request, id string) string
DeleteState(w http.ResponseWriter, r *http.Request, id string) error
}

// ClientToken implements client side storage for signed authorization tokens.
type ClientToken interface {
GetToken(r *http.Request) string
SetToken(w http.ResponseWriter, r *http.Request, value string, maxAge time.Duration)
}

const stateCookiePrefix = "saml_"
const defaultCookieName = "token"

// ClientCookies implements ClientState and ClientToken using cookies.
type ClientCookies struct {
ServiceProvider *saml.ServiceProvider
Name string
Domain string
Secure bool
}

// SetState stores the named state value by setting a cookie.
func (c ClientCookies) SetState(w http.ResponseWriter, r *http.Request, id string, value string) {
http.SetCookie(w, &http.Cookie{
Name: stateCookiePrefix + id,
Value: value,
MaxAge: int(saml.MaxIssueDelay.Seconds()),
HttpOnly: true,
Secure: c.Secure || r.URL.Scheme == "https",
Path: c.ServiceProvider.AcsURL.Path,
})
}

// GetStates returns the currently stored states by reading cookies.
func (c ClientCookies) GetStates(r *http.Request) map[string]string {
rv := map[string]string{}
for _, cookie := range r.Cookies() {
if !strings.HasPrefix(cookie.Name, stateCookiePrefix) {
continue
}
name := strings.TrimPrefix(cookie.Name, stateCookiePrefix)
rv[name] = cookie.Value
}
return rv
}

// GetState returns a single stored state by reading the cookies
func (c ClientCookies) GetState(r *http.Request, id string) string {
stateCookie, err := r.Cookie(stateCookiePrefix + id)
if err != nil {
return ""
}
return stateCookie.Value
}

// DeleteState removes the named stored state by clearing the corresponding cookie.
func (c ClientCookies) DeleteState(w http.ResponseWriter, r *http.Request, id string) error {
cookie, err := r.Cookie(stateCookiePrefix + id)
if err != nil {
return err
}
cookie.Value = ""
cookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{}
http.SetCookie(w, cookie)
return nil
}

// SetToken assigns the specified token by setting a cookie.
func (c ClientCookies) SetToken(w http.ResponseWriter, r *http.Request, value string, maxAge time.Duration) {
http.SetCookie(w, &http.Cookie{
Name: c.Name,
Domain: c.Domain,
Value: value,
MaxAge: int(maxAge.Seconds()),
HttpOnly: true,
Secure: c.Secure || r.URL.Scheme == "https",
Path: "/",
})
}

// GetToken returns the token by reading the cookie.
func (c ClientCookies) GetToken(r *http.Request) string {
cookie, err := r.Cookie(c.Name)
if err != nil {
return ""
}
return cookie.Value
}

var _ ClientState = ClientCookies{}
var _ ClientToken = ClientCookies{}
67 changes: 18 additions & 49 deletions samlsp/middleware.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,7 @@ import (
"crypto/x509"
"encoding/base64"
"encoding/xml"
"fmt"
"net/http"
"strings"
"time"

"github.com/crewjam/saml"
Expand Down Expand Up @@ -47,15 +45,11 @@ import (
type Middleware struct {
ServiceProvider saml.ServiceProvider
AllowIDPInitiated bool
CookieName string
CookieMaxAge time.Duration
CookieDomain string
CookieSecure bool
TokenMaxAge time.Duration
ClientState ClientState
ClientToken ClientToken
}

const defaultCookieMaxAge = time.Hour
const defaultCookieName = "token"

var jwtSigningMethod = jwt.SigningMethodHS256

func randomBytes(n int) []byte {
Expand Down Expand Up @@ -145,15 +139,7 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler {
return
}

http.SetCookie(w, &http.Cookie{
Name: fmt.Sprintf("saml_%s", relayState),
Value: signedState,
MaxAge: int(saml.MaxIssueDelay.Seconds()),
HttpOnly: true,
Secure: m.CookieSecure || r.URL.Scheme == "https",
Path: m.ServiceProvider.AcsURL.Path,
})

m.ClientState.SetState(w, r, relayState, signedState)
if binding == saml.HTTPRedirectBinding {
redirectURL := req.Redirect(relayState)
w.Header().Add("Location", redirectURL.String())
Expand All @@ -178,16 +164,11 @@ func (m *Middleware) RequireAccount(handler http.Handler) http.Handler {

func (m *Middleware) getPossibleRequestIDs(r *http.Request) []string {
rv := []string{}
for _, cookie := range r.Cookies() {
if !strings.HasPrefix(cookie.Name, "saml_") {
continue
}
m.ServiceProvider.Logger.Printf("getPossibleRequestIDs: cookie: %s", cookie.String())

for _, value := range m.ClientState.GetStates(r) {
jwtParser := jwt.Parser{
ValidMethods: []string{jwtSigningMethod.Name},
}
token, err := jwtParser.Parse(cookie.Value, func(t *jwt.Token) (interface{}, error) {
token, err := jwtParser.Parse(value, func(t *jwt.Token) (interface{}, error) {
secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key)
return secretBlock, nil
})
Expand All @@ -214,39 +195,37 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key)

redirectURI := "/"
if r.Form.Get("RelayState") != "" {
stateCookie, err := r.Cookie(fmt.Sprintf("saml_%s", r.Form.Get("RelayState")))
if err != nil {
m.ServiceProvider.Logger.Printf("cannot find corresponding cookie: %s", fmt.Sprintf("saml_%s", r.Form.Get("RelayState")))
if relayState := r.Form.Get("RelayState"); relayState != "" {
stateValue := m.ClientState.GetState(r, relayState)
if stateValue == "" {
m.ServiceProvider.Logger.Printf("cannot find corresponding state: %s", relayState)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}

jwtParser := jwt.Parser{
ValidMethods: []string{jwtSigningMethod.Name},
}
state, err := jwtParser.Parse(stateCookie.Value, func(t *jwt.Token) (interface{}, error) {
state, err := jwtParser.Parse(stateValue, func(t *jwt.Token) (interface{}, error) {
return secretBlock, nil
})
if err != nil || !state.Valid {
m.ServiceProvider.Logger.Printf("Cannot decode state JWT: %s (%s)", err, stateCookie.Value)
m.ServiceProvider.Logger.Printf("Cannot decode state JWT: %s (%s)", err, stateValue)
http.Error(w, http.StatusText(http.StatusForbidden), http.StatusForbidden)
return
}
claims := state.Claims.(jwt.MapClaims)
redirectURI = claims["uri"].(string)

// delete the cookie
stateCookie.Value = ""
stateCookie.Expires = time.Unix(1, 0) // past time as close to epoch as possible, but not zero time.Time{}
http.SetCookie(w, stateCookie)
m.ClientState.DeleteState(w, r, relayState)
}

now := saml.TimeNow()
claims := AuthorizationToken{}
claims.Audience = m.ServiceProvider.Metadata().EntityID
claims.IssuedAt = now.Unix()
claims.ExpiresAt = now.Add(m.CookieMaxAge).Unix()
claims.ExpiresAt = now.Add(m.TokenMaxAge).Unix()
claims.NotBefore = now.Unix()
if sub := assertion.Subject; sub != nil {
if nameID := sub.NameID; nameID != nil {
Expand All @@ -265,23 +244,13 @@ func (m *Middleware) Authorize(w http.ResponseWriter, r *http.Request, assertion
}
}
}

signedToken, err := jwt.NewWithClaims(jwt.SigningMethodHS256,
claims).SignedString(secretBlock)
if err != nil {
panic(err)
}

http.SetCookie(w, &http.Cookie{
Name: m.CookieName,
Domain: m.CookieDomain,
Value: signedToken,
MaxAge: int(m.CookieMaxAge.Seconds()),
HttpOnly: true,
Secure: m.CookieSecure || r.URL.Scheme == "https",
Path: "/",
})

m.ClientToken.SetToken(w, r, signedToken, m.TokenMaxAge)
http.Redirect(w, r, redirectURI, http.StatusFound)
}

Expand All @@ -298,13 +267,13 @@ func (m *Middleware) IsAuthorized(r *http.Request) bool {
// SAML login flow. If the request is authorized, then the request context is
// ammended with a Context object.
func (m *Middleware) GetAuthorizationToken(r *http.Request) *AuthorizationToken {
cookie, err := r.Cookie(m.CookieName)
if err != nil {
tokenStr := m.ClientToken.GetToken(r)
if tokenStr == "" {
return nil
}

tokenClaims := AuthorizationToken{}
token, err := jwt.ParseWithClaims(cookie.Value, &tokenClaims, func(t *jwt.Token) (interface{}, error) {
token, err := jwt.ParseWithClaims(tokenStr, &tokenClaims, func(t *jwt.Token) (interface{}, error) {
secretBlock := x509.MarshalPKCS1PrivateKey(m.ServiceProvider.Key)
return secretBlock, nil
})
Expand Down
12 changes: 9 additions & 3 deletions samlsp/middleware_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,9 +75,14 @@ func (test *MiddlewareTest) SetUpTest(c *C) {
IDPMetadata: &saml.EntityDescriptor{},
Logger: logger.DefaultLogger,
},
CookieName: "ttt",
CookieMaxAge: time.Hour * 2,
TokenMaxAge: time.Hour * 2,
}
cookieStore := ClientCookies{
ServiceProvider: &test.Middleware.ServiceProvider,
Name: "ttt",
}
test.Middleware.ClientState = &cookieStore
test.Middleware.ClientToken = &cookieStore
err := xml.Unmarshal([]byte(test.IDPMetadata), &test.Middleware.ServiceProvider.IDPMetadata)
c.Assert(err, IsNil)
}
Expand Down Expand Up @@ -149,7 +154,8 @@ func (test *MiddlewareTest) TestRequireAccountNoCreds(c *C) {
}

func (test *MiddlewareTest) TestRequireAccountNoCredsSecure(c *C) {
test.Middleware.CookieSecure = true
cookieStore := test.Middleware.ClientState.(*ClientCookies)
cookieStore.Secure = true
handler := test.Middleware.RequireAccount(
http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
panic("not reached")
Expand Down
21 changes: 14 additions & 7 deletions samlsp/samlsp.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ import (
"github.com/crewjam/saml/logger"
)

const defaultTokenMaxAge = time.Hour

// Options represents the parameters for creating a new middleware
type Options struct {
URL url.URL
Expand All @@ -33,7 +35,6 @@ type Options struct {

// New creates a new Middleware
func New(opts Options) (*Middleware, error) {

metadataURL := opts.URL
metadataURL.Path = metadataURL.Path + "/saml/metadata"
acsURL := opts.URL
Expand All @@ -43,9 +44,9 @@ func New(opts Options) (*Middleware, error) {
logr = logger.DefaultLogger
}

cookieMaxAge := opts.CookieMaxAge
tokenMaxAge := opts.CookieMaxAge
if opts.CookieMaxAge == 0 {
cookieMaxAge = defaultCookieMaxAge
tokenMaxAge = defaultTokenMaxAge
}

m := &Middleware{
Expand All @@ -59,11 +60,17 @@ func New(opts Options) (*Middleware, error) {
ForceAuthn: &opts.ForceAuthn,
},
AllowIDPInitiated: opts.AllowIDPInitiated,
CookieName: defaultCookieName,
CookieMaxAge: cookieMaxAge,
CookieDomain: opts.URL.Host,
CookieSecure: opts.CookieSecure,
TokenMaxAge: tokenMaxAge,
}

cookieStore := ClientCookies{
ServiceProvider: &m.ServiceProvider,
Name: defaultCookieName,
Domain: opts.URL.Host,
Secure: opts.CookieSecure,
}
m.ClientState = &cookieStore
m.ClientToken = &cookieStore

// fetch the IDP metadata if needed.
if opts.IDPMetadataURL == nil {
Expand Down

0 comments on commit 4a651a6

Please sign in to comment.