Skip to content

Commit

Permalink
finished tests
Browse files Browse the repository at this point in the history
  • Loading branch information
lleadbet committed Jun 14, 2021
1 parent 647a3b8 commit 9721fd8
Show file tree
Hide file tree
Showing 6 changed files with 276 additions and 22 deletions.
11 changes: 2 additions & 9 deletions internal/mock_api/authentication/authentication.go
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,13 @@ func AuthenticationMiddleware(next mock_api.MockEndpoint) http.Handler {
w.WriteHeader(http.StatusMethodNotAllowed)
return
}
if len(r.URL.Query()["skip_auth"]) > 0 && r.URL.Query()["skip_auth"][0] == "true" {
fakeAuth := UserAuthentication{}
r = r.WithContext(context.WithValue(r.Context(), "auth", fakeAuth))
next.ServeHTTP(w, r)
log.Printf("Skipping auth...")
return
}

clientID := r.Header.Get("Client-ID")
bearerToken := r.Header.Get("Authorization")
unauthroizedError := mock_errors.GetErrorBytes(http.StatusUnauthorized, errors.New("Unauthorized"), "Missing Client ID or OAuth token")
if clientID == "" || bearerToken == "" || len(bearerToken) < 7 {
w.Write(unauthroizedError)
w.WriteHeader(http.StatusUnauthorized)
w.Write(unauthroizedError)
return
}

Expand All @@ -53,8 +46,8 @@ func AuthenticationMiddleware(next mock_api.MockEndpoint) http.Handler {

// check if the client ID is invalid or missing the proper token prefix
if len(clientID) < 30 || prefix != "bearer" {
w.Write(unauthroizedError)
w.WriteHeader(http.StatusUnauthorized)
w.Write(unauthroizedError)
return
}

Expand Down
138 changes: 138 additions & 0 deletions internal/mock_api/authentication/authentication_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,138 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package authentication

import (
"context"
"fmt"
"log"
"net/http"
"net/http/httptest"
"testing"
"time"

"github.com/stretchr/testify/assert"
"github.com/twitchdev/twitch-cli/internal/database"
"github.com/twitchdev/twitch-cli/internal/util"
"github.com/twitchdev/twitch-cli/test_setup"
)

var a *assert.Assertions
var ac = database.AuthenticationClient{ID: "1234", Secret: "1234", Name: "test_client", IsExtension: false}
var token = "potato"
var firstRun = true

func TestHasScope(t *testing.T) {
a = test_setup.SetupTestEnv(t)

u := UserAuthentication{UserID: "1", Scopes: []string{"user:read:email"}}

a.Equal(true, u.HasScope("user:read:email"))
a.Equal(false, u.HasScope("user:read"))

a.Equal(true, u.HasOneOfRequiredScope([]string{}))
a.Equal(true, u.HasOneOfRequiredScope([]string{"user:read:email", "user:read"}))
a.Equal(false, u.HasOneOfRequiredScope([]string{"user:read"}))

}

func TestNatchesBroadcasterIDParam(t *testing.T) {
a = test_setup.SetupTestEnv(t)

req, _ := http.NewRequest(http.MethodGet, "http://google.com", nil)
u := UserAuthentication{UserID: "1", Scopes: []string{"user:read:email"}}

q := req.URL.Query()
q.Set("broadcaster_id", "2")
req.URL.RawQuery = q.Encode()

a.Equal(false, u.MatchesBroadcasterIDParam(req))

q.Set("broadcaster_id", "1")
req.URL.RawQuery = q.Encode()

a.Equal(true, u.MatchesBroadcasterIDParam(req))
}

func TestAuthenticationMiddleware(t *testing.T) {
a = test_setup.SetupTestEnv(t)
ts := httptest.NewServer(baseMiddleware(AuthenticationMiddleware(testEndpoint{})))

req, _ := http.NewRequest(http.MethodGet, ts.URL+testEndpoint{}.Path(), nil)

resp, err := http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(401, resp.StatusCode)

resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(401, resp.StatusCode)

req.Header.Set("Client-ID", ac.ID)
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(401, resp.StatusCode)

req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", token))
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(200, resp.StatusCode)

req.Header.Set("Authorization", fmt.Sprintf("Bearer%v", token))
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(401, resp.StatusCode)

req.Header.Set("Authorization", fmt.Sprintf("Bearer"))
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(401, resp.StatusCode)
}

func baseMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()

// just stub it all
db, err := database.NewConnection()
if err != nil {
log.Fatalf("Error connecting to database: %v", err.Error())
return
}
if firstRun == true {
ac, err = db.NewQuery(r, 100).InsertOrUpdateAuthenticationClient(ac, false)
a.Nil(err)
auth, err := db.NewQuery(r, 100).CreateAuthorization(database.Authorization{ClientID: ac.ID, UserID: "1", Scopes: "user:read:email bits:read", Token: token, ExpiresAt: util.GetTimestamp().Add(7 * 24 * time.Hour).Format(time.RFC3339)})
token = auth.Token
a.Nil(err)

firstRun = false
}

defer db.DB.Close()

ctx = context.WithValue(ctx, "db", db)
r = r.WithContext(ctx)

next.ServeHTTP(w, r)
})
}

type testEndpoint struct{}

func (e testEndpoint) Path() string { return "/endpoint" }

func (e testEndpoint) GetRequiredScopes(method string) []string {
return []string{}
}

func (e testEndpoint) ValidMethod(method string) bool {
return true
}

func (e testEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
userCtx := r.Context().Value("auth").(UserAuthentication)

a.NotNil(userCtx)
w.WriteHeader(200)
}
10 changes: 4 additions & 6 deletions internal/mock_auth/app_access_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,6 @@ package mock_auth

import (
"encoding/json"
"errors"
"log"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -53,19 +51,19 @@ func (e AppAccessTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request
}

if !areValidScopes(scopes, APP_ACCES_TOKEN) {
log.Printf("%v", scopes)
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Invalid scopes requested"))
mock_errors.WriteBadRequest(w, "Invalid scopes requested")
return
}

res, err := db.NewQuery(r, 10).GetAuthenticationClient(database.AuthenticationClient{ID: clientID, Secret: clientSecret})
if err != nil {
w.Write(mock_errors.GetErrorBytes(http.StatusInternalServerError, err, err.Error()))
mock_errors.WriteServerError(w, err.Error())
return
}

ac := res.Data.([]database.AuthenticationClient)
if len(ac) == 0 {
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Client ID/Secret invalid"))
mock_errors.WriteBadRequest(w, "Client ID/Secret invalid")
return
}

Expand Down
126 changes: 126 additions & 0 deletions internal/mock_auth/mock_auth_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// SPDX-License-Identifier: Apache-2.0
package mock_auth

import (
"context"
"log"
"net/http"
"net/http/httptest"
"testing"

"github.com/stretchr/testify/assert"
"github.com/twitchdev/twitch-cli/internal/database"
"github.com/twitchdev/twitch-cli/internal/util"
"github.com/twitchdev/twitch-cli/test_setup"
)

var a *assert.Assertions
var firstRun = true
var ac = database.AuthenticationClient{ID: "222", Secret: "333", Name: "test_client", IsExtension: false}

func TestAreValidScopes(t *testing.T) {
a := test_setup.SetupTestEnv(t)

a.Equal(true, areValidScopes([]string{"user:read:email"}, USER_ACCESS_TOKEN))
a.Equal(false, areValidScopes([]string{"user:read:email"}, APP_ACCES_TOKEN))
}

func TestUserToken(t *testing.T) {
a = test_setup.SetupTestEnv(t)
ts := httptest.NewServer(baseMiddleware(UserTokenEndpoint{}))

req, _ := http.NewRequest(http.MethodPost, ts.URL+UserTokenEndpoint{}.Path(), nil)
q := req.URL.Query()

req.URL.RawQuery = q.Encode()
resp, err := http.DefaultClient.Do(req)
a.Nil(err, err)
a.Equal(400, resp.StatusCode)

// valid values
q.Set("client_id", ac.ID)
q.Set("client_secret", ac.Secret)
q.Set("grant_type", "user_token")
q.Set("user_id", "1")

q.Set("scope", "potato")
req.URL.RawQuery = q.Encode()
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(400, resp.StatusCode)

q.Set("scope", "")
req.URL.RawQuery = q.Encode()
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(200, resp.StatusCode)

q.Set("client_id", "1234")
req.URL.RawQuery = q.Encode()
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(400, resp.StatusCode)

q.Set("client_id", ac.ID)
q.Set("user_id", util.RandomGUID())
req.URL.RawQuery = q.Encode()
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(400, resp.StatusCode)
}

func TestAppAccessToken(t *testing.T) {
a = test_setup.SetupTestEnv(t)
ts := httptest.NewServer(baseMiddleware(AppAccessTokenEndpoint{}))

req, _ := http.NewRequest(http.MethodPost, ts.URL+AppAccessTokenEndpoint{}.Path(), nil)
q := req.URL.Query()

req.URL.RawQuery = q.Encode()
resp, err := http.DefaultClient.Do(req)
a.Nil(err, err)
a.Equal(400, resp.StatusCode)

// valid values
q.Set("client_id", ac.ID)
q.Set("client_secret", ac.Secret)
q.Set("grant_type", "client_credentials")

q.Set("scope", "potato")
req.URL.RawQuery = q.Encode()
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(400, resp.StatusCode)

q.Set("scope", "")
req.URL.RawQuery = q.Encode()
resp, err = http.DefaultClient.Do(req)
a.Nil(err)
a.Equal(200, resp.StatusCode)
}
func baseMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
ctx := context.Background()

// just stub it all
db, err := database.NewConnection()
if err != nil {
log.Fatalf("Error connecting to database: %v", err.Error())
return
}
if firstRun == true {
ac, err = db.NewQuery(r, 100).InsertOrUpdateAuthenticationClient(ac, false)
a.Nil(err, err)

firstRun = false
}

defer db.DB.Close()

ctx = context.WithValue(ctx, "db", db)
r = r.WithContext(ctx)

next.ServeHTTP(w, r)
})
}
13 changes: 6 additions & 7 deletions internal/mock_auth/user_token.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ package mock_auth

import (
"encoding/json"
"errors"
"net/http"
"strings"
"time"
Expand Down Expand Up @@ -32,36 +31,36 @@ func (e UserTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) {
scopes := strings.Split(scope, " ")

if clientID == "" || clientSecret == "" || grantType != "user_token" || userID == "" {
w.WriteHeader(http.StatusBadRequest)
mock_errors.WriteBadRequest(w, "missing required parameter")
return
}

if areValidScopes(scopes, USER_ACCESS_TOKEN) != true {
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Invalid scopes requested"))
mock_errors.WriteBadRequest(w, "Invalid scopes requested")
return
}

res, err := db.NewQuery(r, 10).GetAuthenticationClient(database.AuthenticationClient{ID: clientID, Secret: clientSecret})
if err != nil {
w.Write(mock_errors.GetErrorBytes(http.StatusInternalServerError, err, err.Error()))
mock_errors.WriteServerError(w, err.Error())
return
}

ac := res.Data.([]database.AuthenticationClient)
if len(ac) == 0 {
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "Client ID/Secret invalid"))
mock_errors.WriteBadRequest(w, "Client ID/Secret invalid")
return
}

res, err = db.NewQuery(r, 10).GetUsers(database.User{ID: userID})
if err != nil {
w.Write(mock_errors.GetErrorBytes(http.StatusInternalServerError, err, err.Error()))
mock_errors.WriteServerError(w, err.Error())
return
}

users := res.Data.([]database.User)
if len(users) == 0 {
w.Write(mock_errors.GetErrorBytes(http.StatusBadRequest, errors.New("Unauthorized"), "User ID invalid"))
mock_errors.WriteBadRequest(w, "User ID invalid")
return
}

Expand Down

0 comments on commit 9721fd8

Please sign in to comment.