Skip to content

Commit 9721fd8

Browse files
committed
finished tests
1 parent 647a3b8 commit 9721fd8

File tree

6 files changed

+276
-22
lines changed

6 files changed

+276
-22
lines changed

internal/mock_api/authentication/authentication.go

+2-9
Original file line numberDiff line numberDiff line change
@@ -31,20 +31,13 @@ func AuthenticationMiddleware(next mock_api.MockEndpoint) http.Handler {
3131
w.WriteHeader(http.StatusMethodNotAllowed)
3232
return
3333
}
34-
if len(r.URL.Query()["skip_auth"]) > 0 && r.URL.Query()["skip_auth"][0] == "true" {
35-
fakeAuth := UserAuthentication{}
36-
r = r.WithContext(context.WithValue(r.Context(), "auth", fakeAuth))
37-
next.ServeHTTP(w, r)
38-
log.Printf("Skipping auth...")
39-
return
40-
}
4134

4235
clientID := r.Header.Get("Client-ID")
4336
bearerToken := r.Header.Get("Authorization")
4437
unauthroizedError := mock_errors.GetErrorBytes(http.StatusUnauthorized, errors.New("Unauthorized"), "Missing Client ID or OAuth token")
4538
if clientID == "" || bearerToken == "" || len(bearerToken) < 7 {
46-
w.Write(unauthroizedError)
4739
w.WriteHeader(http.StatusUnauthorized)
40+
w.Write(unauthroizedError)
4841
return
4942
}
5043

@@ -53,8 +46,8 @@ func AuthenticationMiddleware(next mock_api.MockEndpoint) http.Handler {
5346

5447
// check if the client ID is invalid or missing the proper token prefix
5548
if len(clientID) < 30 || prefix != "bearer" {
56-
w.Write(unauthroizedError)
5749
w.WriteHeader(http.StatusUnauthorized)
50+
w.Write(unauthroizedError)
5851
return
5952
}
6053

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,138 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package authentication
4+
5+
import (
6+
"context"
7+
"fmt"
8+
"log"
9+
"net/http"
10+
"net/http/httptest"
11+
"testing"
12+
"time"
13+
14+
"github.com/stretchr/testify/assert"
15+
"github.com/twitchdev/twitch-cli/internal/database"
16+
"github.com/twitchdev/twitch-cli/internal/util"
17+
"github.com/twitchdev/twitch-cli/test_setup"
18+
)
19+
20+
var a *assert.Assertions
21+
var ac = database.AuthenticationClient{ID: "1234", Secret: "1234", Name: "test_client", IsExtension: false}
22+
var token = "potato"
23+
var firstRun = true
24+
25+
func TestHasScope(t *testing.T) {
26+
a = test_setup.SetupTestEnv(t)
27+
28+
u := UserAuthentication{UserID: "1", Scopes: []string{"user:read:email"}}
29+
30+
a.Equal(true, u.HasScope("user:read:email"))
31+
a.Equal(false, u.HasScope("user:read"))
32+
33+
a.Equal(true, u.HasOneOfRequiredScope([]string{}))
34+
a.Equal(true, u.HasOneOfRequiredScope([]string{"user:read:email", "user:read"}))
35+
a.Equal(false, u.HasOneOfRequiredScope([]string{"user:read"}))
36+
37+
}
38+
39+
func TestNatchesBroadcasterIDParam(t *testing.T) {
40+
a = test_setup.SetupTestEnv(t)
41+
42+
req, _ := http.NewRequest(http.MethodGet, "http://google.com", nil)
43+
u := UserAuthentication{UserID: "1", Scopes: []string{"user:read:email"}}
44+
45+
q := req.URL.Query()
46+
q.Set("broadcaster_id", "2")
47+
req.URL.RawQuery = q.Encode()
48+
49+
a.Equal(false, u.MatchesBroadcasterIDParam(req))
50+
51+
q.Set("broadcaster_id", "1")
52+
req.URL.RawQuery = q.Encode()
53+
54+
a.Equal(true, u.MatchesBroadcasterIDParam(req))
55+
}
56+
57+
func TestAuthenticationMiddleware(t *testing.T) {
58+
a = test_setup.SetupTestEnv(t)
59+
ts := httptest.NewServer(baseMiddleware(AuthenticationMiddleware(testEndpoint{})))
60+
61+
req, _ := http.NewRequest(http.MethodGet, ts.URL+testEndpoint{}.Path(), nil)
62+
63+
resp, err := http.DefaultClient.Do(req)
64+
a.Nil(err)
65+
a.Equal(401, resp.StatusCode)
66+
67+
resp, err = http.DefaultClient.Do(req)
68+
a.Nil(err)
69+
a.Equal(401, resp.StatusCode)
70+
71+
req.Header.Set("Client-ID", ac.ID)
72+
resp, err = http.DefaultClient.Do(req)
73+
a.Nil(err)
74+
a.Equal(401, resp.StatusCode)
75+
76+
req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token))
77+
resp, err = http.DefaultClient.Do(req)
78+
a.Nil(err)
79+
a.Equal(200, resp.StatusCode)
80+
81+
req.Header.Set("Authorization", fmt.Sprintf("Bearer%v", token))
82+
resp, err = http.DefaultClient.Do(req)
83+
a.Nil(err)
84+
a.Equal(401, resp.StatusCode)
85+
86+
req.Header.Set("Authorization", fmt.Sprintf("Bearer"))
87+
resp, err = http.DefaultClient.Do(req)
88+
a.Nil(err)
89+
a.Equal(401, resp.StatusCode)
90+
}
91+
92+
func baseMiddleware(next http.Handler) http.Handler {
93+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
94+
ctx := context.Background()
95+
96+
// just stub it all
97+
db, err := database.NewConnection()
98+
if err != nil {
99+
log.Fatalf("Error connecting to database: %v", err.Error())
100+
return
101+
}
102+
if firstRun == true {
103+
ac, err = db.NewQuery(r, 100).InsertOrUpdateAuthenticationClient(ac, false)
104+
a.Nil(err)
105+
auth, err := db.NewQuery(r, 100).CreateAuthorization(database.Authorization{ClientID: ac.ID, UserID: "1", Scopes: "user:read:email bits:read", Token: token, ExpiresAt: util.GetTimestamp().Add(7 * 24 * time.Hour).Format(time.RFC3339)})
106+
token = auth.Token
107+
a.Nil(err)
108+
109+
firstRun = false
110+
}
111+
112+
defer db.DB.Close()
113+
114+
ctx = context.WithValue(ctx, "db", db)
115+
r = r.WithContext(ctx)
116+
117+
next.ServeHTTP(w, r)
118+
})
119+
}
120+
121+
type testEndpoint struct{}
122+
123+
func (e testEndpoint) Path() string { return "/endpoint" }
124+
125+
func (e testEndpoint) GetRequiredScopes(method string) []string {
126+
return []string{}
127+
}
128+
129+
func (e testEndpoint) ValidMethod(method string) bool {
130+
return true
131+
}
132+
133+
func (e testEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
134+
userCtx := r.Context().Value("auth").(UserAuthentication)
135+
136+
a.NotNil(userCtx)
137+
w.WriteHeader(200)
138+
}

internal/mock_auth/app_access_token.go

+4-6
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,6 @@ package mock_auth
44

55
import (
66
"encoding/json"
7-
"errors"
8-
"log"
97
"net/http"
108
"strings"
119
"time"
@@ -53,19 +51,19 @@ func (e AppAccessTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request
5351
}
5452

5553
if !areValidScopes(scopes, APP_ACCES_TOKEN) {
56-
log.Printf("%v", scopes)
57-
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Invalid scopes requested"))
54+
mock_errors.WriteBadRequest(w, "Invalid scopes requested")
5855
return
5956
}
6057

6158
res, err := db.NewQuery(r, 10).GetAuthenticationClient(database.AuthenticationClient{ID: clientID, Secret: clientSecret})
6259
if err != nil {
63-
w.Write(mock_errors.GetErrorBytes(http.StatusInternalServerError, err, err.Error()))
60+
mock_errors.WriteServerError(w, err.Error())
61+
return
6462
}
6563

6664
ac := res.Data.([]database.AuthenticationClient)
6765
if len(ac) == 0 {
68-
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Client ID/Secret invalid"))
66+
mock_errors.WriteBadRequest(w, "Client ID/Secret invalid")
6967
return
7068
}
7169

internal/mock_auth/mock_auth_test.go

+126
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,126 @@
1+
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
2+
// SPDX-License-Identifier: Apache-2.0
3+
package mock_auth
4+
5+
import (
6+
"context"
7+
"log"
8+
"net/http"
9+
"net/http/httptest"
10+
"testing"
11+
12+
"github.com/stretchr/testify/assert"
13+
"github.com/twitchdev/twitch-cli/internal/database"
14+
"github.com/twitchdev/twitch-cli/internal/util"
15+
"github.com/twitchdev/twitch-cli/test_setup"
16+
)
17+
18+
var a *assert.Assertions
19+
var firstRun = true
20+
var ac = database.AuthenticationClient{ID: "222", Secret: "333", Name: "test_client", IsExtension: false}
21+
22+
func TestAreValidScopes(t *testing.T) {
23+
a := test_setup.SetupTestEnv(t)
24+
25+
a.Equal(true, areValidScopes([]string{"user:read:email"}, USER_ACCESS_TOKEN))
26+
a.Equal(false, areValidScopes([]string{"user:read:email"}, APP_ACCES_TOKEN))
27+
}
28+
29+
func TestUserToken(t *testing.T) {
30+
a = test_setup.SetupTestEnv(t)
31+
ts := httptest.NewServer(baseMiddleware(UserTokenEndpoint{}))
32+
33+
req, _ := http.NewRequest(http.MethodPost, ts.URL+UserTokenEndpoint{}.Path(), nil)
34+
q := req.URL.Query()
35+
36+
req.URL.RawQuery = q.Encode()
37+
resp, err := http.DefaultClient.Do(req)
38+
a.Nil(err, err)
39+
a.Equal(400, resp.StatusCode)
40+
41+
// valid values
42+
q.Set("client_id", ac.ID)
43+
q.Set("client_secret", ac.Secret)
44+
q.Set("grant_type", "user_token")
45+
q.Set("user_id", "1")
46+
47+
q.Set("scope", "potato")
48+
req.URL.RawQuery = q.Encode()
49+
resp, err = http.DefaultClient.Do(req)
50+
a.Nil(err)
51+
a.Equal(400, resp.StatusCode)
52+
53+
q.Set("scope", "")
54+
req.URL.RawQuery = q.Encode()
55+
resp, err = http.DefaultClient.Do(req)
56+
a.Nil(err)
57+
a.Equal(200, resp.StatusCode)
58+
59+
q.Set("client_id", "1234")
60+
req.URL.RawQuery = q.Encode()
61+
resp, err = http.DefaultClient.Do(req)
62+
a.Nil(err)
63+
a.Equal(400, resp.StatusCode)
64+
65+
q.Set("client_id", ac.ID)
66+
q.Set("user_id", util.RandomGUID())
67+
req.URL.RawQuery = q.Encode()
68+
resp, err = http.DefaultClient.Do(req)
69+
a.Nil(err)
70+
a.Equal(400, resp.StatusCode)
71+
}
72+
73+
func TestAppAccessToken(t *testing.T) {
74+
a = test_setup.SetupTestEnv(t)
75+
ts := httptest.NewServer(baseMiddleware(AppAccessTokenEndpoint{}))
76+
77+
req, _ := http.NewRequest(http.MethodPost, ts.URL+AppAccessTokenEndpoint{}.Path(), nil)
78+
q := req.URL.Query()
79+
80+
req.URL.RawQuery = q.Encode()
81+
resp, err := http.DefaultClient.Do(req)
82+
a.Nil(err, err)
83+
a.Equal(400, resp.StatusCode)
84+
85+
// valid values
86+
q.Set("client_id", ac.ID)
87+
q.Set("client_secret", ac.Secret)
88+
q.Set("grant_type", "client_credentials")
89+
90+
q.Set("scope", "potato")
91+
req.URL.RawQuery = q.Encode()
92+
resp, err = http.DefaultClient.Do(req)
93+
a.Nil(err)
94+
a.Equal(400, resp.StatusCode)
95+
96+
q.Set("scope", "")
97+
req.URL.RawQuery = q.Encode()
98+
resp, err = http.DefaultClient.Do(req)
99+
a.Nil(err)
100+
a.Equal(200, resp.StatusCode)
101+
}
102+
func baseMiddleware(next http.Handler) http.Handler {
103+
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
104+
ctx := context.Background()
105+
106+
// just stub it all
107+
db, err := database.NewConnection()
108+
if err != nil {
109+
log.Fatalf("Error connecting to database: %v", err.Error())
110+
return
111+
}
112+
if firstRun == true {
113+
ac, err = db.NewQuery(r, 100).InsertOrUpdateAuthenticationClient(ac, false)
114+
a.Nil(err, err)
115+
116+
firstRun = false
117+
}
118+
119+
defer db.DB.Close()
120+
121+
ctx = context.WithValue(ctx, "db", db)
122+
r = r.WithContext(ctx)
123+
124+
next.ServeHTTP(w, r)
125+
})
126+
}

internal/mock_auth/user_token.go

+6-7
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ package mock_auth
44

55
import (
66
"encoding/json"
7-
"errors"
87
"net/http"
98
"strings"
109
"time"
@@ -32,36 +31,36 @@ func (e UserTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
3231
scopes := strings.Split(scope, " ")
3332

3433
if clientID == "" || clientSecret == "" || grantType != "user_token" || userID == "" {
35-
w.WriteHeader(http.StatusBadRequest)
34+
mock_errors.WriteBadRequest(w, "missing required parameter")
3635
return
3736
}
3837

3938
if areValidScopes(scopes, USER_ACCESS_TOKEN) != true {
40-
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Invalid scopes requested"))
39+
mock_errors.WriteBadRequest(w, "Invalid scopes requested")
4140
return
4241
}
4342

4443
res, err := db.NewQuery(r, 10).GetAuthenticationClient(database.AuthenticationClient{ID: clientID, Secret: clientSecret})
4544
if err != nil {
46-
w.Write(mock_errors.GetErrorBytes(http.StatusInternalServerError, err, err.Error()))
45+
mock_errors.WriteServerError(w, err.Error())
4746
return
4847
}
4948

5049
ac := res.Data.([]database.AuthenticationClient)
5150
if len(ac) == 0 {
52-
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Client ID/Secret invalid"))
51+
mock_errors.WriteBadRequest(w, "Client ID/Secret invalid")
5352
return
5453
}
5554

5655
res, err = db.NewQuery(r, 10).GetUsers(database.User{ID: userID})
5756
if err != nil {
58-
w.Write(mock_errors.GetErrorBytes(http.StatusInternalServerError, err, err.Error()))
57+
mock_errors.WriteServerError(w, err.Error())
5958
return
6059
}
6160

6261
users := res.Data.([]database.User)
6362
if len(users) == 0 {
64-
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "User ID invalid"))
63+
mock_errors.WriteBadRequest(w, "User ID invalid")
6564
return
6665
}
6766

0 commit comments

Comments
 (0)