diff --git a/oauthproxy.go b/oauthproxy.go index 62084cbf3..936fee508 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -249,13 +249,16 @@ func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (e log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh) ok = p.Validator(email) log.Printf("re-validating %s valid:%v", email, ok) - if ok { - ok = p.provider.ValidateToken(access_token) - log.Printf("re-validating access token. valid:%v", ok) + if !ok { + return } - if ok { + if ok, new_token := p.provider.ValidateToken(access_token); ok { + if new_token != "" { + value = new_token + } p.SetCookie(rw, req, value) } + log.Printf("re-validating access token. valid:%v", ok) } } return diff --git a/oauthproxy_test.go b/oauthproxy_test.go index ed02b880d..14082607f 100644 --- a/oauthproxy_test.go +++ b/oauthproxy_test.go @@ -98,8 +98,8 @@ func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (strin return tp.EmailAddress, nil } -func (tp *TestProvider) ValidateToken(access_token string) bool { - return tp.ValidToken +func (tp *TestProvider) ValidateToken(access_token string) (bool, string) { + return tp.ValidToken, "" } type PassAccessTokenTest struct { diff --git a/providers/github.go b/providers/github.go index b138af6bf..030469b6c 100644 --- a/providers/github.go +++ b/providers/github.go @@ -188,6 +188,6 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri return "", nil } -func (p *GitHubProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, nil) +func (p *GitHubProvider) ValidateToken(access_token string) (bool, string) { + return validateToken(p, access_token, nil), "" } diff --git a/providers/google.go b/providers/google.go index 40a62283c..e326fdead 100644 --- a/providers/google.go +++ b/providers/google.go @@ -7,6 +7,7 @@ import ( "errors" "fmt" "io/ioutil" + "log" "net/http" "net/url" "strings" @@ -81,8 +82,27 @@ func jwtDecodeSegment(seg string) ([]byte, error) { return base64.URLEncoding.DecodeString(seg) } -func (p *GoogleProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, nil) +func (p *GoogleProvider) ValidateToken(access_token string) (ok bool, new_token string) { + var orig_token string + var refresh_token string + + if components := strings.Split(access_token, " "); len(components) != 2 { + return + } else { + orig_token, refresh_token = components[0], components[1] + } + + if ok = validateToken(p, orig_token, nil); ok == true { + return + } + log.Printf("original token expired; redeeming refresh token") + if renewed, err := p.redeemRefreshToken(refresh_token); err == nil { + new_token = fmt.Sprintf("%s %s", renewed, refresh_token) + ok = true + } else { + log.Printf("redeeming refresh token failed: %v", err) + } + return } func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) { @@ -128,7 +148,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st return } - token, err = p.redeemRefreshToken(jsonResponse.RefreshToken) + token = fmt.Sprintf("%s %s", jsonResponse.AccessToken, jsonResponse.RefreshToken) return } diff --git a/providers/google_test.go b/providers/google_test.go index 755164099..c9c8a1294 100644 --- a/providers/google_test.go +++ b/providers/google_test.go @@ -4,6 +4,7 @@ import ( "encoding/base64" "encoding/json" "github.com/bmizerany/assert" + "net/http/httptest" "net/url" "testing" ) @@ -19,6 +20,44 @@ func newGoogleProvider() *GoogleProvider { Scope: ""}) } +func testGoogleRedeemBackend(data *ProviderData, redirect_uri string, code string, payload string) (server *httptest.Server) { + path := "/oauth2/v3/token" + form := url.Values{ + "redirect_uri": {redirect_uri}, + "client_id": {"0"}, + "client_secret": {""}, + "code": {code}, + "grant_type": {"authorization_code"}, + } + server = NewTestPostBackend(path, form, payload) + data.RedeemUrl, _ = url.Parse(server.URL) + data.RedeemUrl.Path = path + return +} + +func testGoogleRedeemRefreshTokenBackend(data *ProviderData, refresh_token string, payload string) (server *httptest.Server) { + path := "/oauth2/v3/token" + form := url.Values{ + "client_id": {"0"}, + "client_secret": {""}, + "refresh_token": {refresh_token}, + "grant_type": {"refresh_token"}, + } + server = NewTestPostBackend(path, form, payload) + data.RedeemUrl, _ = url.Parse(server.URL) + data.RedeemUrl.Path = path + return +} + +func testGoogleValidateTokenBackend(data *ProviderData, access_token string, payload string) (server *httptest.Server) { + path := "/oauth2/v1/tokeninfo" + query := "access_token=" + access_token + server = NewTestQueryBackend(path, query, payload) + data.ValidateUrl, _ = url.Parse(server.URL) + data.ValidateUrl.Path = path + return +} + func TestGoogleProviderDefaults(t *testing.T) { p := newGoogleProvider() assert.NotEqual(t, nil, p) @@ -126,3 +165,66 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) { assert.Equal(t, "", email) assert.NotEqual(t, nil, err) } + +func TestGoogleProviderReedeem(t *testing.T) { + p, redirect_uri, code := newGoogleProvider(), "/redirect", "my code" + payload := `{"access_token":"access_token",` + + ` "expires_in":3920,` + + ` "token_type":"Bearer",` + + ` "refresh_token":"refresh_token"}` + server := testGoogleRedeemBackend(p.Data(), redirect_uri, code, payload) + defer server.Close() + response, token, err := p.Redeem(redirect_uri, code) + assert.Equal(t, nil, err) + assert.Equal(t, payload, string(response)) + assert.Equal(t, "access_token refresh_token", token) +} + +func TestGoogleProviderReedeemRefreshToken(t *testing.T) { + p, refresh_token := newGoogleProvider(), "refresh_token" + payload := `{"access_token":"new_access_token",` + + ` "expires_in":3920,` + + ` "token_type":"Bearer"}` + server := testGoogleRedeemRefreshTokenBackend(p.Data(), refresh_token, payload) + defer server.Close() + token, err := p.redeemRefreshToken(refresh_token) + assert.Equal(t, nil, err) + assert.Equal(t, "new_access_token", token) +} + +func TestGoogleProviderValidateToken(t *testing.T) { + p := newGoogleProvider() + access_token := "access_token" + refresh_token := "refresh_token" + full_token := access_token + " " + refresh_token + + server := testGoogleValidateTokenBackend(p.Data(), access_token, "") + defer server.Close() + + ok, token := p.ValidateToken(full_token) + assert.Equal(t, true, ok) + assert.Equal(t, "", token) +} + +func TestGoogleProviderValidateTokenReturnRefreshedToken(t *testing.T) { + p := newGoogleProvider() + access_token := "access_token" + refresh_token := "refresh_token" + full_token := access_token + " " + refresh_token + + // Not setting a path, etc. will force an error. + validate_server := NewTestQueryBackend("", "", "") + defer validate_server.Close() + p.Data().ValidateUrl, _ = url.Parse(validate_server.URL) + + refresh_payload := `{"access_token":"new_access_token",` + + ` "expires_in":3920,` + + ` "token_type":"Bearer"}` + refresh_server := testGoogleRedeemRefreshTokenBackend( + p.Data(), refresh_token, refresh_payload) + defer refresh_server.Close() + + ok, token := p.ValidateToken(full_token) + assert.Equal(t, true, ok) + assert.Equal(t, "new_access_token "+refresh_token, token) +} diff --git a/providers/internal_util.go b/providers/internal_util.go index 4ccd0378a..0f22c2170 100644 --- a/providers/internal_util.go +++ b/providers/internal_util.go @@ -9,9 +9,9 @@ import ( "github.com/bitly/oauth2_proxy/api" ) -func validateToken(p Provider, access_token string, header http.Header) bool { +func validateToken(p Provider, access_token string, header http.Header) (ok bool) { if access_token == "" || p.Data().ValidateUrl == nil { - return false + return } endpoint := p.Data().ValidateUrl.String() if len(header) == 0 { @@ -21,14 +21,15 @@ func validateToken(p Provider, access_token string, header http.Header) bool { resp, err := api.RequestUnparsedResponse(endpoint, header) if err != nil { log.Printf("token validation request failed: %s", err) - return false + return } body, _ := ioutil.ReadAll(resp.Body) resp.Body.Close() if resp.StatusCode == 200 { - return true + ok = true + return } log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body) - return false + return } diff --git a/providers/internal_util_test.go b/providers/internal_util_test.go index 36a1d370f..7d58b4dce 100644 --- a/providers/internal_util_test.go +++ b/providers/internal_util_test.go @@ -18,8 +18,8 @@ func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token s // Note that we're testing the internal validateToken() used to implement // several Provider's ValidateToken() implementations -func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool { - return false +func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) (ok bool, new_token string) { + return } type ValidateTokenTest struct { diff --git a/providers/linkedin.go b/providers/linkedin.go index 6249ec480..e3a3c5eb7 100644 --- a/providers/linkedin.go +++ b/providers/linkedin.go @@ -74,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (st return email, nil } -func (p *LinkedInProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, getLinkedInHeader(access_token)) +func (p *LinkedInProvider) ValidateToken(access_token string) (bool, string) { + return validateToken(p, access_token, getLinkedInHeader(access_token)), "" } diff --git a/providers/myusa.go b/providers/myusa.go index 707263942..2077c7dc2 100644 --- a/providers/myusa.go +++ b/providers/myusa.go @@ -57,6 +57,6 @@ func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (strin return json.Get("email").String() } -func (p *MyUsaProvider) ValidateToken(access_token string) bool { - return validateToken(p, access_token, nil) +func (p *MyUsaProvider) ValidateToken(access_token string) (bool, string) { + return validateToken(p, access_token, nil), "" } diff --git a/providers/myusa_test.go b/providers/myusa_test.go index 32e8520e8..243269616 100644 --- a/providers/myusa_test.go +++ b/providers/myusa_test.go @@ -2,7 +2,6 @@ package providers import ( "github.com/bmizerany/assert" - "net/http" "net/http/httptest" "net/url" "testing" @@ -34,17 +33,7 @@ func testMyUsaProvider(hostname string) *MyUsaProvider { func testMyUsaBackend(payload string) *httptest.Server { path := "/api/v1/profile" query := "access_token=imaginary_access_token" - - return httptest.NewServer(http.HandlerFunc( - func(w http.ResponseWriter, r *http.Request) { - url := r.URL - if url.Path != path || url.RawQuery != query { - w.WriteHeader(404) - } else { - w.WriteHeader(200) - w.Write([]byte(payload)) - } - })) + return NewTestQueryBackend(path, query, payload) } func TestMyUsaProviderDefaults(t *testing.T) { diff --git a/providers/providers.go b/providers/providers.go index b7e84eb04..c65517c0b 100644 --- a/providers/providers.go +++ b/providers/providers.go @@ -4,7 +4,7 @@ type Provider interface { Data() *ProviderData GetEmailAddress(body []byte, access_token string) (string, error) Redeem(string, string) ([]byte, string, error) - ValidateToken(access_token string) bool + ValidateToken(access_token string) (ok bool, new_token string) GetLoginURL(redirectURI, finalRedirect string) string } diff --git a/providers/test_util.go b/providers/test_util.go new file mode 100644 index 000000000..ff4d5c790 --- /dev/null +++ b/providers/test_util.go @@ -0,0 +1,44 @@ +package providers + +import ( + "log" + "net/http" + "net/http/httptest" + "net/url" + "reflect" +) + +func NewTestQueryBackend(path, query, payload string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + url := r.URL + if url.Path != path || url.RawQuery != query { + log.Printf("unexpected request:\n"+ + " expected: %s?%s\n"+ + " actual: %s?%s", path, query, + url.Path, url.RawQuery) + w.WriteHeader(404) + } else { + w.WriteHeader(200) + w.Write([]byte(payload)) + } + })) +} + +func NewTestPostBackend(path string, form url.Values, payload string) *httptest.Server { + return httptest.NewServer(http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + url := r.URL + r.ParseForm() + if url.Path != path || reflect.DeepEqual(r.Form, form) { + log.Printf("unexpected request:\n"+ + " expected: %s\n %v\n"+ + " actual: %s\n %v", path, form, + url.Path, r.Form) + w.WriteHeader(404) + } else { + w.WriteHeader(200) + w.Write([]byte(payload)) + } + })) +}