diff --git a/main.go b/main.go index 287dc4894..700271eaa 100644 --- a/main.go +++ b/main.go @@ -78,6 +78,7 @@ func main() { flagSet.String("validate-url", "", "Access token validation endpoint") flagSet.String("scope", "", "OAuth scope specification") flagSet.String("approval-prompt", "force", "OAuth approval_prompt") + flagSet.String("upstream-auth", "", "What to pass in the Authorization header to the upstream. Only supported value: (id_token)") flagSet.String("signature-key", "", "GAP-Signature request signature key (algorithm:secretkey)") diff --git a/oauthproxy.go b/oauthproxy.go index 21e5dfc74..1b5ab9b16 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -66,6 +66,7 @@ type OAuthProxy struct { PassUserHeaders bool BasicAuthPassword string PassAccessToken bool + UpstreamAuth string CookieCipher *cookie.Cipher skipAuthRegex []string skipAuthPreflight bool @@ -163,7 +164,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { log.Printf("Cookie settings: name:%s secure(https):%v httponly:%v expiry:%s domain:%s refresh:%s", opts.CookieName, opts.CookieSecure, opts.CookieHttpOnly, opts.CookieExpire, opts.CookieDomain, refresh) var cipher *cookie.Cipher - if opts.PassAccessToken || (opts.CookieRefresh != time.Duration(0)) { + if opts.PassAccessToken || opts.UpstreamAuth != "" || (opts.CookieRefresh != time.Duration(0)) { var err error cipher, err = cookie.NewCipher(secretBytes(opts.CookieSecret)) if err != nil { @@ -202,6 +203,7 @@ func NewOAuthProxy(opts *Options, validator func(string) bool) *OAuthProxy { PassUserHeaders: opts.PassUserHeaders, BasicAuthPassword: opts.BasicAuthPassword, PassAccessToken: opts.PassAccessToken, + UpstreamAuth: opts.UpstreamAuth, SkipProviderButton: opts.SkipProviderButton, CookieCipher: cipher, templates: loadTemplates(opts.CustomTemplatesDir), @@ -698,6 +700,9 @@ func (p *OAuthProxy) Authenticate(rw http.ResponseWriter, req *http.Request) int if p.PassAccessToken && session.AccessToken != "" { req.Header["X-Forwarded-Access-Token"] = []string{session.AccessToken} } + if p.UpstreamAuth == "id_token" && session.IdToken != "" { + req.Header["Authorization"] = []string{"Bearer " + session.IdToken} + } if session.Email == "" { rw.Header().Set("GAP-Auth", session.User) } else { diff --git a/options.go b/options.go index 949fbba80..daf321919 100644 --- a/options.go +++ b/options.go @@ -61,6 +61,7 @@ type Options struct { SSLInsecureSkipVerify bool `flag:"ssl-insecure-skip-verify" cfg:"ssl_insecure_skip_verify"` SetXAuthRequest bool `flag:"set-xauthrequest" cfg:"set_xauthrequest"` SkipAuthPreflight bool `flag:"skip-auth-preflight" cfg:"skip_auth_preflight"` + UpstreamAuth string `flag:"upstream-auth" cfg:"upstream_auth"` // These options allow for other providers besides Google, with // potential overrides. diff --git a/providers/google.go b/providers/google.go index 66406bd2b..026e90e0e 100644 --- a/providers/google.go +++ b/providers/google.go @@ -142,6 +142,7 @@ func (p *GoogleProvider) Redeem(redirectURL, code string) (s *SessionState, err } s = &SessionState{ AccessToken: jsonResponse.AccessToken, + IdToken: jsonResponse.IdToken, ExpiresOn: time.Now().Add(time.Duration(jsonResponse.ExpiresIn) * time.Second).Truncate(time.Second), RefreshToken: jsonResponse.RefreshToken, Email: email, diff --git a/providers/oidc.go b/providers/oidc.go index 0c0fa52a9..b9cd741cc 100644 --- a/providers/oidc.go +++ b/providers/oidc.go @@ -66,6 +66,7 @@ func (p *OIDCProvider) Redeem(redirectURL, code string) (s *SessionState, err er s = &SessionState{ AccessToken: token.AccessToken, RefreshToken: token.RefreshToken, + IdToken: rawIDToken, ExpiresOn: token.Expiry, Email: claims.Email, } diff --git a/providers/session_state.go b/providers/session_state.go index 805c702f5..2184bab4b 100644 --- a/providers/session_state.go +++ b/providers/session_state.go @@ -11,6 +11,7 @@ import ( type SessionState struct { AccessToken string + IdToken string ExpiresOn time.Time RefreshToken string Email string @@ -35,6 +36,9 @@ func (s *SessionState) String() string { if s.RefreshToken != "" { o += " refresh_token:true" } + if s.IdToken != "" { + o += fmt.Sprintf(" id_token:%s", s.IdToken) + } return o + "}" } @@ -66,7 +70,13 @@ func (s *SessionState) EncryptedString(c *cookie.Cipher) (string, error) { return "", err } } - return fmt.Sprintf("%s|%s|%d|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r), nil + i := s.IdToken + if i != "" { + if i, err = c.Encrypt(i); err != nil { + return "", err + } + } + return fmt.Sprintf("%s|%s|%d|%s|%s", s.accountInfo(), a, s.ExpiresOn.Unix(), r, i), nil } func decodeSessionStatePlain(v string) (s *SessionState, err error) { @@ -90,8 +100,8 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) } chunks := strings.Split(v, "|") - if len(chunks) != 4 { - err = fmt.Errorf("invalid number of fields (got %d expected 4)", len(chunks)) + if len(chunks) != 5 { + err = fmt.Errorf("invalid number of fields (got %d expected 5)", len(chunks)) return } @@ -115,5 +125,11 @@ func DecodeSessionState(v string, c *cookie.Cipher) (s *SessionState, err error) } } + if chunks[4] != "" { + if sessionState.IdToken, err = c.Decrypt(chunks[4]); err != nil { + return nil, err + } + } + return sessionState, nil }