Skip to content
This repository was archived by the owner on Jan 24, 2019. It is now read-only.
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 7 additions & 4 deletions oauthproxy.go
Original file line number Diff line number Diff line change
Expand Up @@ -249,13 +249,16 @@ func (p *OauthProxy) ProcessCookie(rw http.ResponseWriter, req *http.Request) (e
log.Printf("refreshing %s old session for %s (refresh after %s)", time.Now().Sub(timestamp), email, p.CookieRefresh)
ok = p.Validator(email)
log.Printf("re-validating %s valid:%v", email, ok)
if ok {
ok = p.provider.ValidateToken(access_token)
log.Printf("re-validating access token. valid:%v", ok)
if !ok {
return
}
if ok {
if ok, new_token := p.provider.ValidateToken(access_token); ok {
if new_token != "" {
value = new_token
}
p.SetCookie(rw, req, value)
}
log.Printf("re-validating access token. valid:%v", ok)
}
}
return
Expand Down
4 changes: 2 additions & 2 deletions oauthproxy_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -98,8 +98,8 @@ func (tp *TestProvider) GetEmailAddress(body []byte, access_token string) (strin
return tp.EmailAddress, nil
}

func (tp *TestProvider) ValidateToken(access_token string) bool {
return tp.ValidToken
func (tp *TestProvider) ValidateToken(access_token string) (bool, string) {
return tp.ValidToken, ""
}

type PassAccessTokenTest struct {
Expand Down
4 changes: 2 additions & 2 deletions providers/github.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,6 +188,6 @@ func (p *GitHubProvider) GetEmailAddress(body []byte, access_token string) (stri
return "", nil
}

func (p *GitHubProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
func (p *GitHubProvider) ValidateToken(access_token string) (bool, string) {
return validateToken(p, access_token, nil), ""
}
26 changes: 23 additions & 3 deletions providers/google.go
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ import (
"errors"
"fmt"
"io/ioutil"
"log"
"net/http"
"net/url"
"strings"
Expand Down Expand Up @@ -81,8 +82,27 @@ func jwtDecodeSegment(seg string) ([]byte, error) {
return base64.URLEncoding.DecodeString(seg)
}

func (p *GoogleProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
func (p *GoogleProvider) ValidateToken(access_token string) (ok bool, new_token string) {
var orig_token string
var refresh_token string

if components := strings.Split(access_token, " "); len(components) != 2 {
return
} else {
orig_token, refresh_token = components[0], components[1]
}

if ok = validateToken(p, orig_token, nil); ok == true {
return
}
log.Printf("original token expired; redeeming refresh token")
if renewed, err := p.redeemRefreshToken(refresh_token); err == nil {
new_token = fmt.Sprintf("%s %s", renewed, refresh_token)
ok = true
} else {
log.Printf("redeeming refresh token failed: %v", err)
}
return
}

func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token string, err error) {
Expand Down Expand Up @@ -128,7 +148,7 @@ func (p *GoogleProvider) Redeem(redirectUrl, code string) (body []byte, token st
return
}

token, err = p.redeemRefreshToken(jsonResponse.RefreshToken)
token = fmt.Sprintf("%s %s", jsonResponse.AccessToken, jsonResponse.RefreshToken)
return
}

Expand Down
102 changes: 102 additions & 0 deletions providers/google_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/base64"
"encoding/json"
"github.com/bmizerany/assert"
"net/http/httptest"
"net/url"
"testing"
)
Expand All @@ -19,6 +20,44 @@ func newGoogleProvider() *GoogleProvider {
Scope: ""})
}

func testGoogleRedeemBackend(data *ProviderData, redirect_uri string, code string, payload string) (server *httptest.Server) {
path := "/oauth2/v3/token"
form := url.Values{
"redirect_uri": {redirect_uri},
"client_id": {"0"},
"client_secret": {""},
"code": {code},
"grant_type": {"authorization_code"},
}
server = NewTestPostBackend(path, form, payload)
data.RedeemUrl, _ = url.Parse(server.URL)
data.RedeemUrl.Path = path
return
}

func testGoogleRedeemRefreshTokenBackend(data *ProviderData, refresh_token string, payload string) (server *httptest.Server) {
path := "/oauth2/v3/token"
form := url.Values{
"client_id": {"0"},
"client_secret": {""},
"refresh_token": {refresh_token},
"grant_type": {"refresh_token"},
}
server = NewTestPostBackend(path, form, payload)
data.RedeemUrl, _ = url.Parse(server.URL)
data.RedeemUrl.Path = path
return
}

func testGoogleValidateTokenBackend(data *ProviderData, access_token string, payload string) (server *httptest.Server) {
path := "/oauth2/v1/tokeninfo"
query := "access_token=" + access_token
server = NewTestQueryBackend(path, query, payload)
data.ValidateUrl, _ = url.Parse(server.URL)
data.ValidateUrl.Path = path
return
}

func TestGoogleProviderDefaults(t *testing.T) {
p := newGoogleProvider()
assert.NotEqual(t, nil, p)
Expand Down Expand Up @@ -126,3 +165,66 @@ func TestGoogleProviderGetEmailAddressEmailMissing(t *testing.T) {
assert.Equal(t, "", email)
assert.NotEqual(t, nil, err)
}

func TestGoogleProviderReedeem(t *testing.T) {
p, redirect_uri, code := newGoogleProvider(), "/redirect", "my code"
payload := `{"access_token":"access_token",` +
` "expires_in":3920,` +
` "token_type":"Bearer",` +
` "refresh_token":"refresh_token"}`
server := testGoogleRedeemBackend(p.Data(), redirect_uri, code, payload)
defer server.Close()
response, token, err := p.Redeem(redirect_uri, code)
assert.Equal(t, nil, err)
assert.Equal(t, payload, string(response))
assert.Equal(t, "access_token refresh_token", token)
}

func TestGoogleProviderReedeemRefreshToken(t *testing.T) {
p, refresh_token := newGoogleProvider(), "refresh_token"
payload := `{"access_token":"new_access_token",` +
` "expires_in":3920,` +
` "token_type":"Bearer"}`
server := testGoogleRedeemRefreshTokenBackend(p.Data(), refresh_token, payload)
defer server.Close()
token, err := p.redeemRefreshToken(refresh_token)
assert.Equal(t, nil, err)
assert.Equal(t, "new_access_token", token)
}

func TestGoogleProviderValidateToken(t *testing.T) {
p := newGoogleProvider()
access_token := "access_token"
refresh_token := "refresh_token"
full_token := access_token + " " + refresh_token

server := testGoogleValidateTokenBackend(p.Data(), access_token, "")
defer server.Close()

ok, token := p.ValidateToken(full_token)
assert.Equal(t, true, ok)
assert.Equal(t, "", token)
}

func TestGoogleProviderValidateTokenReturnRefreshedToken(t *testing.T) {
p := newGoogleProvider()
access_token := "access_token"
refresh_token := "refresh_token"
full_token := access_token + " " + refresh_token

// Not setting a path, etc. will force an error.
validate_server := NewTestQueryBackend("", "", "")
defer validate_server.Close()
p.Data().ValidateUrl, _ = url.Parse(validate_server.URL)

refresh_payload := `{"access_token":"new_access_token",` +
` "expires_in":3920,` +
` "token_type":"Bearer"}`
refresh_server := testGoogleRedeemRefreshTokenBackend(
p.Data(), refresh_token, refresh_payload)
defer refresh_server.Close()

ok, token := p.ValidateToken(full_token)
assert.Equal(t, true, ok)
assert.Equal(t, "new_access_token "+refresh_token, token)
}
11 changes: 6 additions & 5 deletions providers/internal_util.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,9 @@ import (
"github.com/bitly/oauth2_proxy/api"
)

func validateToken(p Provider, access_token string, header http.Header) bool {
func validateToken(p Provider, access_token string, header http.Header) (ok bool) {
if access_token == "" || p.Data().ValidateUrl == nil {
return false
return
}
endpoint := p.Data().ValidateUrl.String()
if len(header) == 0 {
Expand All @@ -21,14 +21,15 @@ func validateToken(p Provider, access_token string, header http.Header) bool {
resp, err := api.RequestUnparsedResponse(endpoint, header)
if err != nil {
log.Printf("token validation request failed: %s", err)
return false
return
}

body, _ := ioutil.ReadAll(resp.Body)
resp.Body.Close()
if resp.StatusCode == 200 {
return true
ok = true
return
}
log.Printf("token validation request failed: status %d - %s", resp.StatusCode, body)
return false
return
}
4 changes: 2 additions & 2 deletions providers/internal_util_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,8 @@ func (tp *ValidateTokenTestProvider) GetEmailAddress(body []byte, access_token s

// Note that we're testing the internal validateToken() used to implement
// several Provider's ValidateToken() implementations
func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) bool {
return false
func (tp *ValidateTokenTestProvider) ValidateToken(access_token string) (ok bool, new_token string) {
return
}

type ValidateTokenTest struct {
Expand Down
4 changes: 2 additions & 2 deletions providers/linkedin.go
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,6 @@ func (p *LinkedInProvider) GetEmailAddress(body []byte, access_token string) (st
return email, nil
}

func (p *LinkedInProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, getLinkedInHeader(access_token))
func (p *LinkedInProvider) ValidateToken(access_token string) (bool, string) {
return validateToken(p, access_token, getLinkedInHeader(access_token)), ""
}
4 changes: 2 additions & 2 deletions providers/myusa.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,6 @@ func (p *MyUsaProvider) GetEmailAddress(body []byte, access_token string) (strin
return json.Get("email").String()
}

func (p *MyUsaProvider) ValidateToken(access_token string) bool {
return validateToken(p, access_token, nil)
func (p *MyUsaProvider) ValidateToken(access_token string) (bool, string) {
return validateToken(p, access_token, nil), ""
}
13 changes: 1 addition & 12 deletions providers/myusa_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ package providers

import (
"github.com/bmizerany/assert"
"net/http"
"net/http/httptest"
"net/url"
"testing"
Expand Down Expand Up @@ -34,17 +33,7 @@ func testMyUsaProvider(hostname string) *MyUsaProvider {
func testMyUsaBackend(payload string) *httptest.Server {
path := "/api/v1/profile"
query := "access_token=imaginary_access_token"

return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path || url.RawQuery != query {
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
return NewTestQueryBackend(path, query, payload)
}

func TestMyUsaProviderDefaults(t *testing.T) {
Expand Down
2 changes: 1 addition & 1 deletion providers/providers.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ type Provider interface {
Data() *ProviderData
GetEmailAddress(body []byte, access_token string) (string, error)
Redeem(string, string) ([]byte, string, error)
ValidateToken(access_token string) bool
ValidateToken(access_token string) (ok bool, new_token string)
GetLoginURL(redirectURI, finalRedirect string) string
}

Expand Down
44 changes: 44 additions & 0 deletions providers/test_util.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
package providers

import (
"log"
"net/http"
"net/http/httptest"
"net/url"
"reflect"
)

func NewTestQueryBackend(path, query, payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
if url.Path != path || url.RawQuery != query {
log.Printf("unexpected request:\n"+
" expected: %s?%s\n"+
" actual: %s?%s", path, query,
url.Path, url.RawQuery)
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}

func NewTestPostBackend(path string, form url.Values, payload string) *httptest.Server {
return httptest.NewServer(http.HandlerFunc(
func(w http.ResponseWriter, r *http.Request) {
url := r.URL
r.ParseForm()
if url.Path != path || reflect.DeepEqual(r.Form, form) {
log.Printf("unexpected request:\n"+
" expected: %s\n %v\n"+
" actual: %s\n %v", path, form,
url.Path, r.Form)
w.WriteHeader(404)
} else {
w.WriteHeader(200)
w.Write([]byte(payload))
}
}))
}