Skip to content

Commit e3fb0fb

Browse files
hickfordgopherbot
authored andcommitted
oauth2: support device flow
Device Authorization Grant following RFC 8628 https://datatracker.ietf.org/doc/html/rfc8628 Tested with GitHub Fixes golang#418 Fixes golang/go#58126 Co-authored-by: cmP <[email protected]> Change-Id: Id588867110c6a5289bf1026da5d7ead88f9c7d14 GitHub-Last-Rev: 9a126d7 GitHub-Pull-Request: golang#609 Reviewed-on: https://go-review.googlesource.com/c/oauth2/+/450155 Commit-Queue: Bryan Mills <[email protected]> TryBot-Result: Gopher Robot <[email protected]> Reviewed-by: Than McIntosh <[email protected]> Auto-Submit: Bryan Mills <[email protected]> Run-TryBot: Matt Hickford <[email protected]> Reviewed-by: Bryan Mills <[email protected]> Run-TryBot: Bryan Mills <[email protected]>
1 parent 0708528 commit e3fb0fb

File tree

6 files changed

+295
-9
lines changed

6 files changed

+295
-9
lines changed

deviceauth.go

+188
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,188 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"io"
8+
"net/http"
9+
"net/url"
10+
"strings"
11+
"time"
12+
13+
"golang.org/x/oauth2/internal"
14+
)
15+
16+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
17+
const (
18+
errAuthorizationPending = "authorization_pending"
19+
errSlowDown = "slow_down"
20+
errAccessDenied = "access_denied"
21+
errExpiredToken = "expired_token"
22+
)
23+
24+
// DeviceAuthResponse describes a successful RFC 8628 Device Authorization Response
25+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
26+
type DeviceAuthResponse struct {
27+
// DeviceCode
28+
DeviceCode string `json:"device_code"`
29+
// UserCode is the code the user should enter at the verification uri
30+
UserCode string `json:"user_code"`
31+
// VerificationURI is where user should enter the user code
32+
VerificationURI string `json:"verification_uri"`
33+
// VerificationURIComplete (if populated) includes the user code in the verification URI. This is typically shown to the user in non-textual form, such as a QR code.
34+
VerificationURIComplete string `json:"verification_uri_complete,omitempty"`
35+
// Expiry is when the device code and user code expire
36+
Expiry time.Time `json:"expires_in,omitempty"`
37+
// Interval is the duration in seconds that Poll should wait between requests
38+
Interval int64 `json:"interval,omitempty"`
39+
}
40+
41+
func (d DeviceAuthResponse) MarshalJSON() ([]byte, error) {
42+
type Alias DeviceAuthResponse
43+
var expiresIn int64
44+
if !d.Expiry.IsZero() {
45+
expiresIn = int64(time.Until(d.Expiry).Seconds())
46+
}
47+
return json.Marshal(&struct {
48+
ExpiresIn int64 `json:"expires_in,omitempty"`
49+
*Alias
50+
}{
51+
ExpiresIn: expiresIn,
52+
Alias: (*Alias)(&d),
53+
})
54+
55+
}
56+
57+
func (c *DeviceAuthResponse) UnmarshalJSON(data []byte) error {
58+
type Alias DeviceAuthResponse
59+
aux := &struct {
60+
ExpiresIn int64 `json:"expires_in"`
61+
*Alias
62+
}{
63+
Alias: (*Alias)(c),
64+
}
65+
if err := json.Unmarshal(data, &aux); err != nil {
66+
return err
67+
}
68+
if aux.ExpiresIn != 0 {
69+
c.Expiry = time.Now().UTC().Add(time.Second * time.Duration(aux.ExpiresIn))
70+
}
71+
return nil
72+
}
73+
74+
// DeviceAuth returns a device auth struct which contains a device code
75+
// and authorization information provided for users to enter on another device.
76+
func (c *Config) DeviceAuth(ctx context.Context, opts ...AuthCodeOption) (*DeviceAuthResponse, error) {
77+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.1
78+
v := url.Values{
79+
"client_id": {c.ClientID},
80+
}
81+
if len(c.Scopes) > 0 {
82+
v.Set("scope", strings.Join(c.Scopes, " "))
83+
}
84+
for _, opt := range opts {
85+
opt.setValue(v)
86+
}
87+
return retrieveDeviceAuth(ctx, c, v)
88+
}
89+
90+
func retrieveDeviceAuth(ctx context.Context, c *Config, v url.Values) (*DeviceAuthResponse, error) {
91+
req, err := http.NewRequest("POST", c.Endpoint.DeviceAuthURL, strings.NewReader(v.Encode()))
92+
if err != nil {
93+
return nil, err
94+
}
95+
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
96+
req.Header.Set("Accept", "application/json")
97+
98+
t := time.Now()
99+
r, err := internal.ContextClient(ctx).Do(req)
100+
if err != nil {
101+
return nil, err
102+
}
103+
104+
body, err := io.ReadAll(io.LimitReader(r.Body, 1<<20))
105+
if err != nil {
106+
return nil, fmt.Errorf("oauth2: cannot auth device: %v", err)
107+
}
108+
if code := r.StatusCode; code < 200 || code > 299 {
109+
return nil, &RetrieveError{
110+
Response: r,
111+
Body: body,
112+
}
113+
}
114+
115+
da := &DeviceAuthResponse{}
116+
err = json.Unmarshal(body, &da)
117+
if err != nil {
118+
return nil, fmt.Errorf("unmarshal %s", err)
119+
}
120+
121+
if !da.Expiry.IsZero() {
122+
// Make a small adjustment to account for time taken by the request
123+
da.Expiry = da.Expiry.Add(-time.Since(t))
124+
}
125+
126+
return da, nil
127+
}
128+
129+
// DeviceAccessToken polls the server to exchange a device code for a token.
130+
func (c *Config) DeviceAccessToken(ctx context.Context, da *DeviceAuthResponse, opts ...AuthCodeOption) (*Token, error) {
131+
if !da.Expiry.IsZero() {
132+
var cancel context.CancelFunc
133+
ctx, cancel = context.WithDeadline(ctx, da.Expiry)
134+
defer cancel()
135+
}
136+
137+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.4
138+
v := url.Values{
139+
"client_id": {c.ClientID},
140+
"grant_type": {"urn:ietf:params:oauth:grant-type:device_code"},
141+
"device_code": {da.DeviceCode},
142+
}
143+
if len(c.Scopes) > 0 {
144+
v.Set("scope", strings.Join(c.Scopes, " "))
145+
}
146+
for _, opt := range opts {
147+
opt.setValue(v)
148+
}
149+
150+
// "If no value is provided, clients MUST use 5 as the default."
151+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.2
152+
interval := da.Interval
153+
if interval == 0 {
154+
interval = 5
155+
}
156+
157+
ticker := time.NewTicker(time.Duration(interval) * time.Second)
158+
defer ticker.Stop()
159+
for {
160+
select {
161+
case <-ctx.Done():
162+
return nil, ctx.Err()
163+
case <-ticker.C:
164+
tok, err := retrieveToken(ctx, c, v)
165+
if err == nil {
166+
return tok, nil
167+
}
168+
169+
e, ok := err.(*RetrieveError)
170+
if !ok {
171+
return nil, err
172+
}
173+
switch e.ErrorCode {
174+
case errSlowDown:
175+
// https://datatracker.ietf.org/doc/html/rfc8628#section-3.5
176+
// "the interval MUST be increased by 5 seconds for this and all subsequent requests"
177+
interval += 5
178+
ticker.Reset(time.Duration(interval) * time.Second)
179+
case errAuthorizationPending:
180+
// Do nothing.
181+
case errAccessDenied, errExpiredToken:
182+
fallthrough
183+
default:
184+
return tok, err
185+
}
186+
}
187+
}
188+
}

deviceauth_test.go

+97
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,97 @@
1+
package oauth2
2+
3+
import (
4+
"context"
5+
"encoding/json"
6+
"fmt"
7+
"strings"
8+
"testing"
9+
"time"
10+
11+
"github.com/google/go-cmp/cmp"
12+
"github.com/google/go-cmp/cmp/cmpopts"
13+
)
14+
15+
func TestDeviceAuthResponseMarshalJson(t *testing.T) {
16+
tests := []struct {
17+
name string
18+
response DeviceAuthResponse
19+
want string
20+
}{
21+
{
22+
name: "empty",
23+
response: DeviceAuthResponse{},
24+
want: `{"device_code":"","user_code":"","verification_uri":""}`,
25+
},
26+
{
27+
name: "soon",
28+
response: DeviceAuthResponse{
29+
Expiry: time.Now().Add(100*time.Second + 999*time.Millisecond),
30+
},
31+
want: `{"expires_in":100,"device_code":"","user_code":"","verification_uri":""}`,
32+
},
33+
}
34+
for _, tc := range tests {
35+
t.Run(tc.name, func(t *testing.T) {
36+
begin := time.Now()
37+
gotBytes, err := json.Marshal(tc.response)
38+
if err != nil {
39+
t.Fatal(err)
40+
}
41+
if strings.Contains(tc.want, "expires_in") && time.Since(begin) > 999*time.Millisecond {
42+
t.Skip("test ran too slowly to compare `expires_in`")
43+
}
44+
got := string(gotBytes)
45+
if got != tc.want {
46+
t.Errorf("want=%s, got=%s", tc.want, got)
47+
}
48+
})
49+
}
50+
}
51+
52+
func TestDeviceAuthResponseUnmarshalJson(t *testing.T) {
53+
tests := []struct {
54+
name string
55+
data string
56+
want DeviceAuthResponse
57+
}{
58+
{
59+
name: "empty",
60+
data: `{}`,
61+
want: DeviceAuthResponse{},
62+
},
63+
{
64+
name: "soon",
65+
data: `{"expires_in":100}`,
66+
want: DeviceAuthResponse{Expiry: time.Now().UTC().Add(100 * time.Second)},
67+
},
68+
}
69+
for _, tc := range tests {
70+
t.Run(tc.name, func(t *testing.T) {
71+
begin := time.Now()
72+
got := DeviceAuthResponse{}
73+
err := json.Unmarshal([]byte(tc.data), &got)
74+
if err != nil {
75+
t.Fatal(err)
76+
}
77+
if !cmp.Equal(got, tc.want, cmpopts.IgnoreUnexported(DeviceAuthResponse{}), cmpopts.EquateApproxTime(time.Second+time.Since(begin))) {
78+
t.Errorf("want=%#v, got=%#v", tc.want, got)
79+
}
80+
})
81+
}
82+
}
83+
84+
func ExampleConfig_DeviceAuth() {
85+
var config Config
86+
ctx := context.Background()
87+
response, err := config.DeviceAuth(ctx)
88+
if err != nil {
89+
panic(err)
90+
}
91+
fmt.Printf("please enter code %s at %s\n", response.UserCode, response.VerificationURI)
92+
token, err := config.DeviceAccessToken(ctx, response)
93+
if err != nil {
94+
panic(err)
95+
}
96+
fmt.Println(token)
97+
}

endpoints/endpoints.go

+4-2
Original file line numberDiff line numberDiff line change
@@ -55,8 +55,9 @@ var Fitbit = oauth2.Endpoint{
5555

5656
// GitHub is the endpoint for Github.
5757
var GitHub = oauth2.Endpoint{
58-
AuthURL: "https://github.com/login/oauth/authorize",
59-
TokenURL: "https://github.com/login/oauth/access_token",
58+
AuthURL: "https://github.com/login/oauth/authorize",
59+
TokenURL: "https://github.com/login/oauth/access_token",
60+
DeviceAuthURL: "https://github.com/login/device/code",
6061
}
6162

6263
// GitLab is the endpoint for GitLab.
@@ -69,6 +70,7 @@ var GitLab = oauth2.Endpoint{
6970
var Google = oauth2.Endpoint{
7071
AuthURL: "https://accounts.google.com/o/oauth2/auth",
7172
TokenURL: "https://oauth2.googleapis.com/token",
73+
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
7274
}
7375

7476
// Heroku is the endpoint for Heroku.

github/github.go

+2-5
Original file line numberDiff line numberDiff line change
@@ -6,11 +6,8 @@
66
package github // import "golang.org/x/oauth2/github"
77

88
import (
9-
"golang.org/x/oauth2"
9+
"golang.org/x/oauth2/endpoints"
1010
)
1111

1212
// Endpoint is Github's OAuth 2.0 endpoint.
13-
var Endpoint = oauth2.Endpoint{
14-
AuthURL: "https://github.com/login/oauth/authorize",
15-
TokenURL: "https://github.com/login/oauth/access_token",
16-
}
13+
var Endpoint = endpoints.GitHub

google/google.go

+1
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ import (
2323
var Endpoint = oauth2.Endpoint{
2424
AuthURL: "https://accounts.google.com/o/oauth2/auth",
2525
TokenURL: "https://oauth2.googleapis.com/token",
26+
DeviceAuthURL: "https://oauth2.googleapis.com/device/code",
2627
AuthStyle: oauth2.AuthStyleInParams,
2728
}
2829

oauth2.go

+3-2
Original file line numberDiff line numberDiff line change
@@ -75,8 +75,9 @@ type TokenSource interface {
7575
// Endpoint represents an OAuth 2.0 provider's authorization and token
7676
// endpoint URLs.
7777
type Endpoint struct {
78-
AuthURL string
79-
TokenURL string
78+
AuthURL string
79+
DeviceAuthURL string
80+
TokenURL string
8081

8182
// AuthStyle optionally specifies how the endpoint wants the
8283
// client ID & client secret sent. The zero value means to

0 commit comments

Comments
 (0)