Skip to content

Commit e98fcc0

Browse files
committed
add context support
1 parent 2b95465 commit e98fcc0

16 files changed

+148
-116
lines changed

generate.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package oauth2
22

33
import (
4+
"context"
45
"net/http"
56
"time"
67
)
@@ -17,11 +18,11 @@ type (
1718

1819
// AuthorizeGenerate generate the authorization code interface
1920
AuthorizeGenerate interface {
20-
Token(data *GenerateBasic) (code string, err error)
21+
Token(ctx context.Context, data *GenerateBasic) (code string, err error)
2122
}
2223

2324
// AccessGenerate generate the access and refresh tokens interface
2425
AccessGenerate interface {
25-
Token(data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
26+
Token(ctx context.Context, data *GenerateBasic, isGenRefresh bool) (access, refresh string, err error)
2627
}
2728
)

generates/access.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package generates
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/base64"
67
"strconv"
78
"strings"
@@ -20,7 +21,7 @@ type AccessGenerate struct {
2021
}
2122

2223
// Token based on the UUID generated token
23-
func (ag *AccessGenerate) Token(data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
24+
func (ag *AccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
2425
buf := bytes.NewBufferString(data.Client.GetID())
2526
buf.WriteString(data.UserID)
2627
buf.WriteString(strconv.FormatInt(data.CreateAt.UnixNano(), 10))

generates/access_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package generates_test
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

@@ -22,7 +23,7 @@ func TestAccess(t *testing.T) {
2223
CreateAt: time.Now(),
2324
}
2425
gen := generates.NewAccessGenerate()
25-
access, refresh, err := gen.Token(data, true)
26+
access, refresh, err := gen.Token(context.Background(), data, true)
2627
So(err, ShouldBeNil)
2728
So(access, ShouldNotBeEmpty)
2829
So(refresh, ShouldNotBeEmpty)

generates/authorize.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package generates
22

33
import (
44
"bytes"
5+
"context"
56
"encoding/base64"
67
"strings"
78

@@ -18,7 +19,7 @@ func NewAuthorizeGenerate() *AuthorizeGenerate {
1819
type AuthorizeGenerate struct{}
1920

2021
// Token based on the UUID generated token
21-
func (ag *AuthorizeGenerate) Token(data *oauth2.GenerateBasic) (string, error) {
22+
func (ag *AuthorizeGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic) (string, error) {
2223
buf := bytes.NewBufferString(data.Client.GetID())
2324
buf.WriteString(data.UserID)
2425
token := uuid.NewMD5(uuid.Must(uuid.NewRandom()), buf.Bytes())

generates/authorize_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package generates_test
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

@@ -22,7 +23,7 @@ func TestAuthorize(t *testing.T) {
2223
CreateAt: time.Now(),
2324
}
2425
gen := generates.NewAuthorizeGenerate()
25-
code, err := gen.Token(data)
26+
code, err := gen.Token(context.Background(), data)
2627
So(err, ShouldBeNil)
2728
So(code, ShouldNotBeEmpty)
2829
Println("\nAuthorize Code:" + code)

generates/jwt_access.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package generates
22

33
import (
4+
"context"
45
"encoding/base64"
56
"strings"
67
"time"
@@ -41,7 +42,7 @@ type JWTAccessGenerate struct {
4142
}
4243

4344
// Token based on the UUID generated token
44-
func (a *JWTAccessGenerate) Token(data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
45+
func (a *JWTAccessGenerate) Token(ctx context.Context, data *oauth2.GenerateBasic, isGenRefresh bool) (string, string, error) {
4546
claims := &JWTAccessClaims{
4647
StandardClaims: jwt.StandardClaims{
4748
Audience: data.Client.GetID(),

generates/jwt_access_test.go

+2-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package generates_test
22

33
import (
4+
"context"
45
"fmt"
56
"testing"
67
"time"
@@ -28,7 +29,7 @@ func TestJWTAccess(t *testing.T) {
2829
}
2930

3031
gen := generates.NewJWTAccessGenerate([]byte("00000000"), jwt.SigningMethodHS512)
31-
access, refresh, err := gen.Token(data, true)
32+
access, refresh, err := gen.Token(context.Background(), data, true)
3233
So(err, ShouldBeNil)
3334
So(access, ShouldNotBeEmpty)
3435
So(refresh, ShouldNotBeEmpty)

manage.go

+9-8
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package oauth2
22

33
import (
4+
"context"
45
"net/http"
56
"time"
67
)
@@ -21,26 +22,26 @@ type TokenGenerateRequest struct {
2122
// Manager authorization management interface
2223
type Manager interface {
2324
// get the client information
24-
GetClient(clientID string) (cli ClientInfo, err error)
25+
GetClient(ctx context.Context, clientID string) (cli ClientInfo, err error)
2526

2627
// generate the authorization token(code)
27-
GenerateAuthToken(rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error)
28+
GenerateAuthToken(ctx context.Context, rt ResponseType, tgr *TokenGenerateRequest) (authToken TokenInfo, err error)
2829

2930
// generate the access token
30-
GenerateAccessToken(rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
31+
GenerateAccessToken(ctx context.Context, rt GrantType, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
3132

3233
// refreshing an access token
33-
RefreshAccessToken(tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
34+
RefreshAccessToken(ctx context.Context, tgr *TokenGenerateRequest) (accessToken TokenInfo, err error)
3435

3536
// use the access token to delete the token information
36-
RemoveAccessToken(access string) (err error)
37+
RemoveAccessToken(ctx context.Context, access string) (err error)
3738

3839
// use the refresh token to delete the token information
39-
RemoveRefreshToken(refresh string) (err error)
40+
RemoveRefreshToken(ctx context.Context, refresh string) (err error)
4041

4142
// according to the access token for corresponding token information
42-
LoadAccessToken(access string) (ti TokenInfo, err error)
43+
LoadAccessToken(ctx context.Context, access string) (ti TokenInfo, err error)
4344

4445
// according to the refresh token for corresponding token information
45-
LoadRefreshToken(refresh string) (ti TokenInfo, err error)
46+
LoadRefreshToken(ctx context.Context, refresh string) (ti TokenInfo, err error)
4647
}

manage/manage_test.go

+29-23
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package manage_test
22

33
import (
4+
"context"
45
"testing"
56
"time"
67

@@ -15,6 +16,7 @@ import (
1516
func TestManager(t *testing.T) {
1617
Convey("Manager test", t, func() {
1718
manager := manage.NewDefaultManager()
19+
ctx := context.Background()
1820

1921
manager.MustTokenStorage(store.NewMemoryTokenStore())
2022

@@ -34,7 +36,7 @@ func TestManager(t *testing.T) {
3436
}
3537

3638
Convey("GetClient test", func() {
37-
cli, err := manager.GetClient("1")
39+
cli, err := manager.GetClient(ctx, "1")
3840
So(err, ShouldBeNil)
3941
So(cli.GetSecret(), ShouldEqual, "11")
4042
})
@@ -55,7 +57,8 @@ func TestManager(t *testing.T) {
5557
}
5658

5759
func testManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
58-
cti, err := manager.GenerateAuthToken(oauth2.Code, tgr)
60+
ctx := context.Background()
61+
cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
5962
So(err, ShouldBeNil)
6063

6164
code := cti.GetCode()
@@ -67,58 +70,59 @@ func testManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
6770
RedirectURI: tgr.RedirectURI,
6871
Code: code,
6972
}
70-
ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams)
73+
ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
7174
So(err, ShouldBeNil)
7275

7376
accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh()
7477
So(accessToken, ShouldNotBeEmpty)
7578
So(refreshToken, ShouldNotBeEmpty)
7679

77-
ainfo, err := manager.LoadAccessToken(accessToken)
80+
ainfo, err := manager.LoadAccessToken(ctx, accessToken)
7881
So(err, ShouldBeNil)
7982
So(ainfo.GetClientID(), ShouldEqual, atParams.ClientID)
8083

81-
arinfo, err := manager.LoadRefreshToken(accessToken)
84+
arinfo, err := manager.LoadRefreshToken(ctx, accessToken)
8285
So(err, ShouldNotBeNil)
8386
So(arinfo, ShouldBeNil)
8487

85-
rainfo, err := manager.LoadAccessToken(refreshToken)
88+
rainfo, err := manager.LoadAccessToken(ctx, refreshToken)
8689
So(err, ShouldNotBeNil)
8790
So(rainfo, ShouldBeNil)
8891

89-
rinfo, err := manager.LoadRefreshToken(refreshToken)
92+
rinfo, err := manager.LoadRefreshToken(ctx, refreshToken)
9093
So(err, ShouldBeNil)
9194
So(rinfo.GetClientID(), ShouldEqual, atParams.ClientID)
9295

9396
atParams.Refresh = refreshToken
9497
atParams.Scope = "owner"
95-
rti, err := manager.RefreshAccessToken(atParams)
98+
rti, err := manager.RefreshAccessToken(ctx, atParams)
9699
So(err, ShouldBeNil)
97100

98101
refreshAT := rti.GetAccess()
99102
So(refreshAT, ShouldNotBeEmpty)
100103

101-
_, err = manager.LoadAccessToken(accessToken)
104+
_, err = manager.LoadAccessToken(ctx, accessToken)
102105
So(err, ShouldNotBeNil)
103106

104-
refreshAInfo, err := manager.LoadAccessToken(refreshAT)
107+
refreshAInfo, err := manager.LoadAccessToken(ctx, refreshAT)
105108
So(err, ShouldBeNil)
106109
So(refreshAInfo.GetScope(), ShouldEqual, "owner")
107110

108-
err = manager.RemoveAccessToken(refreshAT)
111+
err = manager.RemoveAccessToken(ctx, refreshAT)
109112
So(err, ShouldBeNil)
110113

111-
_, err = manager.LoadAccessToken(refreshAT)
114+
_, err = manager.LoadAccessToken(ctx, refreshAT)
112115
So(err, ShouldNotBeNil)
113116

114-
err = manager.RemoveRefreshToken(refreshToken)
117+
err = manager.RemoveRefreshToken(ctx, refreshToken)
115118
So(err, ShouldBeNil)
116119

117-
_, err = manager.LoadRefreshToken(refreshToken)
120+
_, err = manager.LoadRefreshToken(ctx, refreshToken)
118121
So(err, ShouldNotBeNil)
119122
}
120123

121124
func testZeroAccessExpirationManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
125+
ctx := context.Background()
122126
config := manage.Config{
123127
AccessTokenExp: 0, // Set explicitly as we're testing 0 (no) expiration
124128
IsGenerateRefresh: true,
@@ -127,7 +131,7 @@ func testZeroAccessExpirationManager(tgr *oauth2.TokenGenerateRequest, manager o
127131
So(ok, ShouldBeTrue)
128132
m.SetAuthorizeCodeTokenCfg(&config)
129133

130-
cti, err := manager.GenerateAuthToken(oauth2.Code, tgr)
134+
cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
131135
So(err, ShouldBeNil)
132136

133137
code := cti.GetCode()
@@ -139,29 +143,30 @@ func testZeroAccessExpirationManager(tgr *oauth2.TokenGenerateRequest, manager o
139143
RedirectURI: tgr.RedirectURI,
140144
Code: code,
141145
}
142-
ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams)
146+
ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
143147
So(err, ShouldBeNil)
144148

145149
accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh()
146150
So(accessToken, ShouldNotBeEmpty)
147151
So(refreshToken, ShouldNotBeEmpty)
148152

149-
tokenInfo, err := manager.LoadAccessToken(accessToken)
153+
tokenInfo, err := manager.LoadAccessToken(ctx, accessToken)
150154
So(err, ShouldBeNil)
151155
So(tokenInfo, ShouldNotBeNil)
152156
So(tokenInfo.GetAccess(), ShouldEqual, accessToken)
153157
So(tokenInfo.GetAccessExpiresIn(), ShouldEqual, 0)
154158
}
155159

156160
func testCannotRequestZeroExpirationAccessTokens(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
161+
ctx := context.Background()
157162
config := manage.Config{
158163
AccessTokenExp: time.Hour * 5,
159164
}
160165
m, ok := manager.(*manage.Manager)
161166
So(ok, ShouldBeTrue)
162167
m.SetAuthorizeCodeTokenCfg(&config)
163168

164-
cti, err := manager.GenerateAuthToken(oauth2.Code, tgr)
169+
cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
165170
So(err, ShouldBeNil)
166171

167172
code := cti.GetCode()
@@ -174,7 +179,7 @@ func testCannotRequestZeroExpirationAccessTokens(tgr *oauth2.TokenGenerateReques
174179
AccessTokenExp: 0, // requesting token without expiration
175180
Code: code,
176181
}
177-
ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams)
182+
ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
178183
So(err, ShouldBeNil)
179184

180185
accessToken := ati.GetAccess()
@@ -183,6 +188,7 @@ func testCannotRequestZeroExpirationAccessTokens(tgr *oauth2.TokenGenerateReques
183188
}
184189

185190
func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager oauth2.Manager) {
191+
ctx := context.Background()
186192
config := manage.Config{
187193
RefreshTokenExp: 0, // Set explicitly as we're testing 0 (no) expiration
188194
IsGenerateRefresh: true,
@@ -191,7 +197,7 @@ func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager
191197
So(ok, ShouldBeTrue)
192198
m.SetAuthorizeCodeTokenCfg(&config)
193199

194-
cti, err := manager.GenerateAuthToken(oauth2.Code, tgr)
200+
cti, err := manager.GenerateAuthToken(ctx, oauth2.Code, tgr)
195201
So(err, ShouldBeNil)
196202

197203
code := cti.GetCode()
@@ -204,21 +210,21 @@ func testZeroRefreshExpirationManager(tgr *oauth2.TokenGenerateRequest, manager
204210
AccessTokenExp: time.Hour,
205211
Code: code,
206212
}
207-
ati, err := manager.GenerateAccessToken(oauth2.AuthorizationCode, atParams)
213+
ati, err := manager.GenerateAccessToken(ctx, oauth2.AuthorizationCode, atParams)
208214
So(err, ShouldBeNil)
209215

210216
accessToken, refreshToken := ati.GetAccess(), ati.GetRefresh()
211217
So(accessToken, ShouldNotBeEmpty)
212218
So(refreshToken, ShouldNotBeEmpty)
213219

214-
tokenInfo, err := manager.LoadRefreshToken(refreshToken)
220+
tokenInfo, err := manager.LoadRefreshToken(ctx, refreshToken)
215221
So(err, ShouldBeNil)
216222
So(tokenInfo, ShouldNotBeNil)
217223
So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken)
218224
So(tokenInfo.GetRefreshExpiresIn(), ShouldEqual, 0)
219225

220226
// LoadAccessToken also checks refresh expiry
221-
tokenInfo, err = manager.LoadAccessToken(accessToken)
227+
tokenInfo, err = manager.LoadAccessToken(ctx, accessToken)
222228
So(err, ShouldBeNil)
223229
So(tokenInfo, ShouldNotBeNil)
224230
So(tokenInfo.GetRefresh(), ShouldEqual, refreshToken)

0 commit comments

Comments
 (0)