From a3e934d7b2d315ed961900bc2b8b6b0a55049546 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 27 Jan 2018 10:14:19 +0000 Subject: [PATCH 1/4] Add Authorization header flags --- main.go | 2 ++ oauthproxy.go | 12 +++++++++++- options.go | 4 ++++ providers/google.go | 1 + providers/oidc.go | 1 + providers/session_state.go | 1 + 6 files changed, 20 insertions(+), 1 deletion(-) diff --git a/main.go b/main.go index 287dc4894..68be1527d 100644 --- a/main.go +++ b/main.go @@ -37,6 +37,8 @@ func main() { flagSet.String("basic-auth-password", "", "the password to set when passing the HTTP Basic Auth header") flagSet.Bool("pass-access-token", false, "pass OAuth access_token to upstream via X-Forwarded-Access-Token header") flagSet.Bool("pass-host-header", true, "pass the request Host Header to upstream") + flagSet.Bool("pass-authorization-header", false, "pass the Authorization Header to upstream") + flagSet.Bool("set-authorization-header", false, "set Authorization response headers (useful in Nginx auth_request mode)") flagSet.Var(&skipAuthRegex, "skip-auth-regex", "bypass authentication for requests path's that match (may be given multiple times)") flagSet.Bool("skip-provider-button", false, "will skip sign-in-page to directly reach the next step: oauth/start") flagSet.Bool("skip-auth-preflight", false, "will skip authentication for OPTIONS requests") diff --git a/oauthproxy.go b/oauthproxy.go index 21e5dfc74..6c48040a3 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -66,6 +66,8 @@ type OAuthProxy struct { PassUserHeaders bool BasicAuthPassword string PassAccessToken bool + SetAuthorization bool + PassAuthorization bool CookieCipher *cookie.Cipher skipAuthRegex []string skipAuthPreflight bool @@ -163,7 +165,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) var cipher *cookie.Cipher - if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { + if opts.PassAccessToken || opts.SetAuthorization || opts.PassAuthorization || (opts.CookieRefresh != time.Duration(0)) { var err error cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) if err != nil { @@ -202,6 +204,8 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { PassUserHeaders: opts.PassUserHeaders, BasicAuthPassword: opts.BasicAuthPassword, PassAccessToken: opts.PassAccessToken, + SetAuthorization: opts.SetAuthorization, + PassAuthorization: opts.PassAuthorization, SkipProviderButton: opts.SkipProviderButton, CookieCipher: cipher, templates: loadTemplates(opts.CustomTemplatesDir), @@ -698,6 +702,12 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int if p.PassAccessToken && session.AccessToken != "" { req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} } + if p.PassAuthorization && session.IdToken != "" { + req.Header["Authorization"] = []string{fmt.Sprintf("Bearer %s", session.IdToken)} + } + if p.SetAuthorization && session.IdToken != "" { + rw.Header().Set("Authorization", fmt.Sprintf("Bearer %s", session.IdToken)) + } if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { diff --git a/options.go b/options.go index 949fbba80..70def5d52 100644 --- a/options.go +++ b/options.go @@ -60,6 +60,8 @@ type Options struct { PassUserHeaders bool `flag:"pass-user-headers" cfg:"pass_user_headers"` SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` + SetAuthorization bool `flag:"set-authorization-header" cfg:"set_authorization_header"` + PassAuthorization bool `flag:"pass-authorization-header" cfg:"pass_authorization_header"` SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` // These options allow for other providers besides Google, with @@ -110,6 +112,8 @@ func NewOptions() *Options { PassUserHeaders: true, PassAccessToken: false, PassHostHeader: true, + SetAuthorization: false, + PassAuthorization: false, ApprovalPrompt: "force", RequestLogging: true, RequestLoggingFormat: defaultRequestLoggingFormat, diff --git a/providers/google.go b/providers/google.go index 66406bd2b..026e90e0e 100644 --- a/providers/google.go +++ b/providers/google.go @@ -142,6 +142,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err } s = &SessionState{ AccessToken: jsonResponse.AccessToken, + IdToken: jsonResponse.IdToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), RefreshToken: jsonResponse.RefreshToken, Email: email, diff --git a/providers/oidc.go b/providers/oidc.go index 0c0fa52a9..58adca077 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -65,6 +65,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er s = &SessionState{ AccessToken: token.AccessToken, + IdToken: rawIDToken, RefreshToken: token.RefreshToken, ExpiresOn: token.Expiry, Email: claims.Email, diff --git a/providers/session_state.go b/providers/session_state.go index 805c702f5..648208ad7 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -11,6 +11,7 @@ import ( type SessionState struct { AccessToken string + IdToken string ExpiresOn time.Time RefreshToken string Email string From 98c751e918160744c117a891bb4f71402b7e6e37 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 27 Jan 2018 10:53:17 +0000 Subject: [PATCH 2/4] Update sessions state --- providers/session_state.go | 27 +++++++++++++++++++++------ providers/session_state_test.go | 7 +++++-- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/providers/session_state.go b/providers/session_state.go index 648208ad7..559c18163 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -30,6 +30,9 @@ func (s *SessionState) String() string { if s.AccessToken != "" { o += " token:true" } + if s.IdToken != "" { + o += " id_token:true" + } if !s.ExpiresOn.IsZero() { o += fmt.Sprintf(" expires:%s", s.ExpiresOn) } @@ -61,13 +64,19 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { return "", err } } + i := s.IdToken + if i != "" { + if i, err = c.Encrypt(i); err != nil { + return "", err + } + } r := s.RefreshToken if r != "" { if r, err = c.Encrypt(r); err != nil { return "", err } } - return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil + return fmt.Sprintf("%s|%s|%s|%d|%s", s.accountInfo(), a, i, s.ExpiresOn.Unix(), r), nil } func decodeSessionStatePlain(v string) (s *SessionState, err error) { @@ -91,8 +100,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) } chunks := strings.Split(v, "|") - if len(chunks) != 4 { - err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) + if len(chunks) != 5 { + err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) return } @@ -107,11 +116,17 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) } } - ts, _ := strconv.Atoi(chunks[2]) + if chunks[2] != "" { + if sessionState.IdToken, err = c.Decrypt(chunks[2]); err != nil { + return nil, err + } + } + + ts, _ := strconv.Atoi(chunks[3]) sessionState.ExpiresOn = time.Unix(int64(ts), 0) - if chunks[3] != "" { - if sessionState.RefreshToken, err = c.Decrypt(chunks[3]); err != nil { + if chunks[4] != "" { + if sessionState.RefreshToken, err = c.Decrypt(chunks[4]); err != nil { return nil, err } } diff --git a/providers/session_state_test.go b/providers/session_state_test.go index d3cc8f881..c3b275d16 100644 --- a/providers/session_state_test.go +++ b/providers/session_state_test.go @@ -21,12 +21,13 @@ func TestSessionStateSerialization(t *testing.T) { s := &SessionState{ Email: "user@domain.com", AccessToken: "token1234", + IdToken: "rawtoken1234", ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour), RefreshToken: "refresh4321", } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 3, strings.Count(encoded, "|")) + assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) @@ -34,6 +35,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, "user", ss.User) assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.AccessToken, ss.AccessToken) + assert.Equal(t, s.IdToken, ss.IdToken) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.Equal(t, s.RefreshToken, ss.RefreshToken) @@ -45,6 +47,7 @@ func TestSessionStateSerialization(t *testing.T) { assert.Equal(t, s.Email, ss.Email) assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix()) assert.NotEqual(t, s.AccessToken, ss.AccessToken) + assert.NotEqual(t, s.IdToken, ss.IdToken) assert.NotEqual(t, s.RefreshToken, ss.RefreshToken) } @@ -62,7 +65,7 @@ func TestSessionStateSerializationWithUser(t *testing.T) { } encoded, err := s.EncodeSessionState(c) assert.Equal(t, nil, err) - assert.Equal(t, 3, strings.Count(encoded, "|")) + assert.Equal(t, 4, strings.Count(encoded, "|")) ss, err := DecodeSessionState(encoded, c) t.Logf("%#v", ss) From 57824fb189d8dea2de8761d28dc48f7836925fe9 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Sat, 27 Jan 2018 22:48:52 +0000 Subject: [PATCH 3/4] Split large cookies --- oauthproxy.go | 104 ++++++++++++++++++++++++++++++++++++++++----- oauthproxy_test.go | 79 ++++++++++++++++++++++++++++++++-- 2 files changed, 168 insertions(+), 15 deletions(-) diff --git a/oauthproxy.go b/oauthproxy.go index 6c48040a3..46d77302f 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -258,15 +258,92 @@ func (p *OAuthProxy) redeemCode(host, code string) (s *providers.SessionState, e return } -func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { +func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expiration time.Duration, now time.Time) []*http.Cookie { if value != "" { value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) - if len(value) > 4096 { - // Cookies cannot be larger than 4kb - log.Printf("WARNING - Cookie Size: %d bytes", len(value)) + } + c := p.makeCookie(req, p.CookieName, value, expiration, now) + if len(c.Value) > 4096 { + return splitCookie(c) + } + return []*http.Cookie{c} +} + +func copyCookie(c *http.Cookie) *http.Cookie { + return &http.Cookie{ + Name: c.Name, + Value: c.Value, + Path: c.Path, + Domain: c.Domain, + Expires: c.Expires, + RawExpires: c.RawExpires, + MaxAge: c.MaxAge, + Secure: c.Secure, + HttpOnly: c.HttpOnly, + Raw: c.Raw, + Unparsed: c.Unparsed, + } +} + +func splitCookie(c *http.Cookie) []*http.Cookie { + if len(c.Value) < 3840 { + return []*http.Cookie{c} + } + cookies := []*http.Cookie{} + valueBytes := []byte(c.Value) + count := 0 + for len(valueBytes) > 0 { + new := copyCookie(c) + new.Name = fmt.Sprintf("%s-%d", c.Name, count) + count++ + if len(valueBytes) < 3840 { + new.Value = string(valueBytes) + valueBytes = []byte{} + } else { + newValue := valueBytes[:3840] + valueBytes = valueBytes[3840:] + new.Value = string(newValue) + } + cookies = append(cookies, new) + } + return cookies +} + +func joinCookies(cookies []*http.Cookie) (*http.Cookie, error) { + if len(cookies) == 0 { + return nil, fmt.Errorf("Could not load cookie.") + } + if len(cookies) == 1 { + return cookies[0], nil + } + c := copyCookie(cookies[0]) + for i := 1; i < len(cookies); i++ { + c.Value += cookies[i].Value + } + c.Name = strings.TrimRight(c.Name, "-0") + return c, nil +} + +func loadCookie(req *http.Request, cookieName string) (*http.Cookie, error) { + c, err := req.Cookie(cookieName) + if err == nil { + return c, nil + } + cookies := []*http.Cookie{} + err = nil + count := 0 + for err == nil { + var c *http.Cookie + c, err = req.Cookie(fmt.Sprintf("%s-%d", cookieName, count)) + if err == nil { + cookies = append(cookies, c) + count++ } } - return p.makeCookie(req, p.CookieName, value, expiration, now) + if len(cookies) == 0 { + return nil, fmt.Errorf("Could not find cookie %s", cookieName) + } + return joinCookies(cookies) } func (p *OAuthProxy) MakeCSRFCookie(req *http.Request, value string, expiration time.Duration, now time.Time) *http.Cookie { @@ -296,6 +373,7 @@ func (p *OAuthProxy) makeCookie(req *http.Request, name string, value string, ex } func (p *OAuthProxy) ClearCSRFCookie(rw http.ResponseWriter, req *http.Request) { + http.SetCookie(rw, p.MakeCSRFCookie(req, "", time.Hour*-1, time.Now())) } @@ -304,24 +382,28 @@ func (p *OAuthProxy) SetCSRFCookie(rw http.ResponseWriter, req *http.Request, va } func (p *OAuthProxy) ClearSessionCookie(rw http.ResponseWriter, req *http.Request) { - clr := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) - http.SetCookie(rw, clr) + cookies := p.MakeSessionCookie(req, "", time.Hour*-1, time.Now()) + for _, clr := range cookies { + http.SetCookie(rw, clr) + } // ugly hack because default domain changed - if p.CookieDomain == "" { - clr2 := *clr + if p.CookieDomain == "" && len(cookies) > 0 { + clr2 := *cookies[0] clr2.Domain = req.Host http.SetCookie(rw, &clr2) } } func (p *OAuthProxy) SetSessionCookie(rw http.ResponseWriter, req *http.Request, val string) { - http.SetCookie(rw, p.MakeSessionCookie(req, val, p.CookieExpire, time.Now())) + for _, c := range p.MakeSessionCookie(req, val, p.CookieExpire, time.Now()) { + http.SetCookie(rw, c) + } } func (p *OAuthProxy) LoadCookiedSession(req *http.Request) (*providers.SessionState, time.Duration, error) { var age time.Duration - c, err := req.Cookie(p.CookieName) + c, err := loadCookie(req, p.CookieName) if err != nil { // always http.ErrNoCookie return nil, age, fmt.Errorf("Cookie %q not present", p.CookieName) diff --git a/oauthproxy_test.go b/oauthproxy_test.go index 1e6b3140d..f313c8910 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -6,6 +6,7 @@ import ( "io" "io/ioutil" "log" + "math/rand" "net" "net/http" "net/http/httptest" @@ -92,6 +93,73 @@ func TestRobotsTxt(t *testing.T) { assert.Equal(t, "User-agent: *\nDisallow: /", rw.Body.String()) } +func randomString(length int) string { + charset := "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789" + seededRand := rand.New(rand.NewSource(time.Now().UnixNano())) + b := make([]byte, length) + for i := range b { + b[i] = charset[seededRand.Intn(len(charset))] + } + return string(b) +} + +func TestSplitCookie(t *testing.T) { + c1 := &http.Cookie{ + Name: "cookie-name", + Value: randomString(5120), + Path: "/", + Domain: "foo.bar", + HttpOnly: true, + Secure: true, + Expires: time.Now(), + } + cookies := splitCookie(c1) + assert.Equal(t, 2, len(cookies)) + + assert.Equal(t, c1.Name+"-0", cookies[0].Name) + assert.Equal(t, c1.Name+"-1", cookies[1].Name) + + assert.Equal(t, 3840, len(cookies[0].Value)) + assert.Equal(t, 5120-3840, len(cookies[1].Value)) + + c2 := &http.Cookie{ + Name: "cookie-name", + Value: randomString(3000), + Path: "/", + Domain: "foo.bar", + HttpOnly: true, + Secure: true, + Expires: time.Now(), + } + + cookies2 := splitCookie(c2) + assert.Equal(t, 1, len(cookies2)) + + assert.Equal(t, c2.Name, cookies2[0].Name) + assert.Equal(t, c2.Value, cookies2[0].Value) +} + +func TestJoinCookies(t *testing.T) { + c1 := &http.Cookie{ + Name: "cookie-name", + Value: randomString(5120), + Path: "/", + Domain: "foo.bar", + HttpOnly: true, + Secure: true, + Expires: time.Now(), + } + // Split Cookies + cookies := splitCookie(c1) + assert.Equal(t, 2, len(cookies)) + + // join cookies should be the ivnerse + c2, _ := joinCookies(cookies) + + assert.Equal(t, c1.Name, c2.Name) + assert.Equal(t, c1.Value, c2.Value) +} + type TestProvider struct { *providers.ProviderData EmailAddress string @@ -504,7 +572,7 @@ func NewProcessCookieTestWithDefaults() *ProcessCookieTest { }) } -func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) *http.Cookie { +func (p *ProcessCookieTest) MakeCookie(value string, ref time.Time) []*http.Cookie { return p.proxy.MakeSessionCookie(p.req, value, p.opts.CookieExpire, ref) } @@ -513,7 +581,9 @@ func (p *ProcessCookieTest) SaveSession(s *providers.SessionState, ref time.Time if err != nil { return err } - p.req.AddCookie(p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref)) + for _, c := range p.proxy.MakeSessionCookie(p.req, value, p.proxy.CookieExpire, ref) { + p.req.AddCookie(c) + } return nil } @@ -802,8 +872,9 @@ func (st *SignatureTest) MakeRequestWithExpectedKey(method, body, key string) { if err != nil { panic(err) } - cookie := proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) - req.AddCookie(cookie) + for _, c := range proxy.MakeSessionCookie(req, value, proxy.CookieExpire, time.Now()) { + req.AddCookie(c) + } // This is used by the upstream to validate the signature. st.authenticator.auth = hmacauth.NewHmacAuth( crypto.SHA1, []byte(key), SignatureHeader, SignatureHeaders) From 202f75655e240887de547b5eefb5dfb480d338b3 Mon Sep 17 00:00:00 2001 From: Joel Speed Date: Wed, 20 Jun 2018 14:56:13 +0100 Subject: [PATCH 4/4] Fix cookie split should account for cookie name --- oauthproxy.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oauthproxy.go b/oauthproxy.go index 46d77302f..b081484e4 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -263,7 +263,7 @@ func (p *OAuthProxy) MakeSessionCookie(req *http.Request, value string, expirati value = cookie.SignedValue(p.CookieSeed, p.CookieName, value, now) } c := p.makeCookie(req, p.CookieName, value, expiration, now) - if len(c.Value) > 4096 { + if len(c.Value) > 4096-len(p.CookieName) { return splitCookie(c) } return []*http.Cookie{c}