From 723f5fc5f08b5b2c2bcb41503573f59dbe1c6593 Mon Sep 17 00:00:00 2001 From: lleadbet Date: Mon, 16 Aug 2021 12:47:49 -0400 Subject: [PATCH 1/2] adding body param support for client credentials --- internal/mock_auth/app_access_token.go | 46 ++++++++++++++++++++++---- 1 file changed, 39 insertions(+), 7 deletions(-) diff --git a/internal/mock_auth/app_access_token.go b/internal/mock_auth/app_access_token.go index 29d43f6c..c7e051d1 100644 --- a/internal/mock_auth/app_access_token.go +++ b/internal/mock_auth/app_access_token.go @@ -15,6 +15,13 @@ import ( type AppAccessTokenEndpoint struct{} +type AppAccessTokenRequestBody struct { + ClientID string `json:"client_id"` + ClientSecret string `json:"client_secret"` + GrantType string `json:"grant_type"` + Scope string `json:"scope"` +} + type AppAccessTokenEndpointResposne struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` @@ -33,12 +40,37 @@ func (e AppAccessTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request return } - clientID := r.URL.Query().Get("client_id") - clientSecret := r.URL.Query().Get("client_secret") - grantType := r.URL.Query().Get("grant_type") - scope := r.URL.Query().Get("scope") - scopes := strings.Split(scope, " ") - if clientID == "" || clientSecret == "" || grantType != "client_credentials" { + params := AppAccessTokenRequestBody{ + ClientID: r.URL.Query().Get("client_id"), + ClientSecret: r.URL.Query().Get("client_secret"), + GrantType: r.URL.Query().Get("grant_type"), + Scope: r.URL.Query().Get("scope"), + } + + if r.Header.Get("Content-Type") == "application/x-www-form-urlencoded" { + err := r.ParseForm() + if err != nil { + mock_errors.WriteServerError(w, err.Error()) + return + } + + if r.Form.Get("client_id") != "" { + params.ClientID = r.Form.Get("client_id") + } + if r.Form.Get("client_secret") != "" { + params.ClientSecret = r.Form.Get("client_secret") + } + if r.Form.Get("grant_type") != "" { + params.GrantType = r.Form.Get("grant_type") + } + if r.Form.Get("scope") != "" { + params.Scope = r.Form.Get("scope") + } + } + + scopes := strings.Split(params.Scope, " ") + + if params.ClientID == "" || params.ClientSecret == "" || params.GrantType != "client_credentials" { w.WriteHeader(http.StatusBadRequest) return } @@ -55,7 +87,7 @@ func (e AppAccessTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request return } - res, err := db.NewQuery(r, 10).GetAuthenticationClient(database.AuthenticationClient{ID: clientID, Secret: clientSecret}) + res, err := db.NewQuery(r, 10).GetAuthenticationClient(database.AuthenticationClient{ID: params.ClientID, Secret: params.ClientSecret}) if err != nil { mock_errors.WriteServerError(w, err.Error()) return From c80ac483da897f7d8770e295d337faf4a94998bd Mon Sep 17 00:00:00 2001 From: lleadbet Date: Mon, 16 Aug 2021 14:18:20 -0400 Subject: [PATCH 2/2] fixing typo + add validate endpoint --- .../mock_api/endpoints/bits/cheermotes.go | 4 +- internal/mock_api/endpoints/clips/clips.go | 10 +-- internal/mock_api/endpoints/polls/polls.go | 12 +-- .../endpoints/predictions/predictions.go | 6 +- .../mock_api/endpoints/streams/streamkey.go | 4 +- .../mock_api/endpoints/teams/channel_teams.go | 6 +- internal/mock_auth/app_access_token.go | 4 +- internal/mock_auth/mock_auth.go | 3 +- internal/mock_auth/mock_auth_test.go | 48 +++++++++++ internal/mock_auth/user_token.go | 2 +- internal/mock_auth/validate.go | 85 +++++++++++++++++++ 11 files changed, 159 insertions(+), 25 deletions(-) create mode 100644 internal/mock_auth/validate.go diff --git a/internal/mock_api/endpoints/bits/cheermotes.go b/internal/mock_api/endpoints/bits/cheermotes.go index 1be273cd..887d0b94 100644 --- a/internal/mock_api/endpoints/bits/cheermotes.go +++ b/internal/mock_api/endpoints/bits/cheermotes.go @@ -128,8 +128,8 @@ func getCheermotes(w http.ResponseWriter, r *http.Request) { cheermoteBody = append(cheermoteBody, cheermote) } - resposne, _ := json.Marshal(cheermoteBody) - w.Write(resposne) + response, _ := json.Marshal(cheermoteBody) + w.Write(response) } func generateCheermoteImageSizes(prefix string, theme string, imageType string, bits int) CheermoteImageSizes { diff --git a/internal/mock_api/endpoints/clips/clips.go b/internal/mock_api/endpoints/clips/clips.go index 9af8d60e..594331b5 100644 --- a/internal/mock_api/endpoints/clips/clips.go +++ b/internal/mock_api/endpoints/clips/clips.go @@ -92,21 +92,21 @@ func getClips(w http.ResponseWriter, r *http.Request) { } clips := dbr.Data.([]database.Clip) - apiResposne := models.APIResponse{ + apiResponse := models.APIResponse{ Data: clips, } - if len(apiResposne.Data.([]database.Clip)) == 0 { - apiResposne.Data = []string{} + if len(apiResponse.Data.([]database.Clip)) == 0 { + apiResponse.Data = []string{} } if dbr.Limit == len(dbr.Data.([]database.Clip)) { - apiResposne.Pagination = &models.APIPagination{ + apiResponse.Pagination = &models.APIPagination{ Cursor: dbr.Cursor, } } - bytes, _ := json.Marshal(apiResposne) + bytes, _ := json.Marshal(apiResponse) w.Write(bytes) } diff --git a/internal/mock_api/endpoints/polls/polls.go b/internal/mock_api/endpoints/polls/polls.go index bacb8000..7c1f00ea 100644 --- a/internal/mock_api/endpoints/polls/polls.go +++ b/internal/mock_api/endpoints/polls/polls.go @@ -110,17 +110,17 @@ func getPolls(w http.ResponseWriter, r *http.Request) { polls = append(polls, dbr.Data.([]database.Poll)...) } - apiResposne := models.APIResponse{ + apiResponse := models.APIResponse{ Data: polls, } if dbr != nil && dbr.Cursor != "" { - apiResposne.Pagination = &models.APIPagination{ + apiResponse.Pagination = &models.APIPagination{ Cursor: dbr.Cursor, } } - bytes, _ := json.Marshal(apiResposne) + bytes, _ := json.Marshal(apiResponse) w.Write(bytes) } @@ -231,16 +231,16 @@ func patchPolls(w http.ResponseWriter, r *http.Request) { return } - apiResposne := models.APIResponse{ + apiResponse := models.APIResponse{ Data: dbr.Data, } if dbr.Cursor != "" { - apiResposne.Pagination = &models.APIPagination{ + apiResponse.Pagination = &models.APIPagination{ Cursor: dbr.Cursor, } } - bytes, _ := json.Marshal(apiResposne) + bytes, _ := json.Marshal(apiResponse) w.Write(bytes) } diff --git a/internal/mock_api/endpoints/predictions/predictions.go b/internal/mock_api/endpoints/predictions/predictions.go index bba08f0c..0af8a9c7 100644 --- a/internal/mock_api/endpoints/predictions/predictions.go +++ b/internal/mock_api/endpoints/predictions/predictions.go @@ -110,17 +110,17 @@ func getPredictions(w http.ResponseWriter, r *http.Request) { predictions = append(predictions, dbr.Data.([]database.Prediction)...) } - apiResposne := models.APIResponse{ + apiResponse := models.APIResponse{ Data: predictions, } if dbr != nil && dbr.Cursor != "" { - apiResposne.Pagination = &models.APIPagination{ + apiResponse.Pagination = &models.APIPagination{ Cursor: dbr.Cursor, } } - bytes, _ := json.Marshal(apiResposne) + bytes, _ := json.Marshal(apiResponse) w.Write(bytes) } diff --git a/internal/mock_api/endpoints/streams/streamkey.go b/internal/mock_api/endpoints/streams/streamkey.go index 411dda86..68974153 100644 --- a/internal/mock_api/endpoints/streams/streamkey.go +++ b/internal/mock_api/endpoints/streams/streamkey.go @@ -32,7 +32,7 @@ var streamKeyScopesByMethod = map[string][]string{ type StreamKey struct{} -type StreamKeyResposne struct { +type StreamKeyResponse struct { StreamKey string `json:"stream_key"` } @@ -66,7 +66,7 @@ func getStreamKey(w http.ResponseWriter, r *http.Request) { return } - streamKeys := []StreamKeyResposne{{StreamKey: fmt.Sprintf("live_%v_%v", userCtx.UserID, util.RandomGUID())}} + streamKeys := []StreamKeyResponse{{StreamKey: fmt.Sprintf("live_%v_%v", userCtx.UserID, util.RandomGUID())}} bytes, _ := json.Marshal(models.APIResponse{Data: streamKeys}) w.Write(bytes) diff --git a/internal/mock_api/endpoints/teams/channel_teams.go b/internal/mock_api/endpoints/teams/channel_teams.go index 10753d8e..16cb110a 100644 --- a/internal/mock_api/endpoints/teams/channel_teams.go +++ b/internal/mock_api/endpoints/teams/channel_teams.go @@ -28,7 +28,7 @@ var channelTeamsScopesByMethod = map[string][]string{ } type ChannelTeams struct{} -type ChannelTeamResposne struct { +type ChannelTeamResponse struct { ID string `json:"id"` BackgroundImageUrl *string `json:"background_image_url"` Banner *string `json:"banner"` @@ -77,9 +77,9 @@ func getChannelTeams(w http.ResponseWriter, r *http.Request) { if len(team) == 0 { dbr.Data = make([]database.Team, 0) } - response := []ChannelTeamResposne{} + response := []ChannelTeamResponse{} for _, t := range team { - response = append(response, ChannelTeamResposne{ + response = append(response, ChannelTeamResponse{ ID: t.ID, Info: t.Info, BackgroundImageUrl: t.BackgroundImageUrl, diff --git a/internal/mock_auth/app_access_token.go b/internal/mock_auth/app_access_token.go index c7e051d1..88f4d55e 100644 --- a/internal/mock_auth/app_access_token.go +++ b/internal/mock_auth/app_access_token.go @@ -22,7 +22,7 @@ type AppAccessTokenRequestBody struct { Scope string `json:"scope"` } -type AppAccessTokenEndpointResposne struct { +type AppAccessTokenEndpointResponse struct { AccessToken string `json:"access_token"` RefreshToken string `json:"refresh_token"` ExpiresIn int `json:"expires_in"` @@ -111,7 +111,7 @@ func (e AppAccessTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request return } ea, _ := time.Parse(time.RFC3339, a.ExpiresAt) - ater := AppAccessTokenEndpointResposne{ + ater := AppAccessTokenEndpointResponse{ AccessToken: auth.Token, RefreshToken: "", ExpiresIn: int(ea.Sub(time.Now().UTC()).Seconds()), diff --git a/internal/mock_auth/mock_auth.go b/internal/mock_auth/mock_auth.go index f3c91289..3c56d82e 100644 --- a/internal/mock_auth/mock_auth.go +++ b/internal/mock_auth/mock_auth.go @@ -58,8 +58,9 @@ var validScopesByTokenType = map[string]map[string]bool{ func All() []AuthEndpoint { return []AuthEndpoint{ - UserTokenEndpoint{}, AppAccessTokenEndpoint{}, + UserTokenEndpoint{}, + ValidateTokenEndpoint{}, } } diff --git a/internal/mock_auth/mock_auth_test.go b/internal/mock_auth/mock_auth_test.go index f651b86b..2b7c14f3 100644 --- a/internal/mock_auth/mock_auth_test.go +++ b/internal/mock_auth/mock_auth_test.go @@ -4,10 +4,13 @@ package mock_auth import ( "context" + "fmt" "log" "net/http" "net/http/httptest" + "os" "testing" + "time" "github.com/stretchr/testify/assert" "github.com/twitchdev/twitch-cli/internal/database" @@ -19,6 +22,10 @@ var a *assert.Assertions var firstRun = true var ac = database.AuthenticationClient{ID: "222", Secret: "333", Name: "test_client", IsExtension: false} +func TestMain(m *testing.M) { + + os.Exit(m.Run()) +} func TestAreValidScopes(t *testing.T) { a := test_setup.SetupTestEnv(t) @@ -70,6 +77,46 @@ func TestUserToken(t *testing.T) { a.Equal(400, resp.StatusCode) } +func TestValidateToken(t *testing.T) { + a = test_setup.SetupTestEnv(t) + ts := httptest.NewServer(baseMiddleware(ValidateTokenEndpoint{})) + + req, _ := http.NewRequest(http.MethodGet, ts.URL+ValidateTokenEndpoint{}.Path(), nil) + resp, err := http.DefaultClient.Do(req) + a.Nil(err, err) + a.Equal(401, resp.StatusCode) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", "auth.Token")) + resp, err = http.DefaultClient.Do(req) + a.Nil(err, err) + a.Equal(401, resp.StatusCode) + + db, err := database.NewConnection() + a.Nil(err, err) + defer db.DB.Close() + + auth, err := db.NewQuery(nil, 0).CreateAuthorization(database.Authorization{ + ClientID: ac.ID, + ExpiresAt: util.GetTimestamp().Add(time.Hour * 4).Format(time.RFC3339), + Scopes: "", + }) + + req.Header.Set("Authorization", fmt.Sprintf("Bearer %v", auth.Token)) + resp, err = http.DefaultClient.Do(req) + a.Nil(err, err) + a.Equal(200, resp.StatusCode) + + auth, err = db.NewQuery(nil, 0).CreateAuthorization(database.Authorization{ + ClientID: ac.ID, + ExpiresAt: util.GetTimestamp().Add(time.Hour * 4).Format(time.RFC3339), + Scopes: "user:read:email", + UserID: "1", + }) + req.Header.Set("Authorization", fmt.Sprintf("Oauth %v", auth.Token)) + resp, err = http.DefaultClient.Do(req) + a.Nil(err, err) + a.Equal(200, resp.StatusCode) +} func TestAppAccessToken(t *testing.T) { a = test_setup.SetupTestEnv(t) ts := httptest.NewServer(baseMiddleware(AppAccessTokenEndpoint{})) @@ -99,6 +146,7 @@ func TestAppAccessToken(t *testing.T) { a.Nil(err) a.Equal(200, resp.StatusCode) } + func baseMiddleware(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { ctx := context.Background() diff --git a/internal/mock_auth/user_token.go b/internal/mock_auth/user_token.go index 8dc0938d..3063e564 100644 --- a/internal/mock_auth/user_token.go +++ b/internal/mock_auth/user_token.go @@ -77,7 +77,7 @@ func (e UserTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } ea, _ := time.Parse(time.RFC3339, a.ExpiresAt) - ater := AppAccessTokenEndpointResposne{ + ater := AppAccessTokenEndpointResponse{ AccessToken: auth.Token, RefreshToken: "", ExpiresIn: int(ea.Sub(time.Now().UTC()).Seconds()), diff --git a/internal/mock_auth/validate.go b/internal/mock_auth/validate.go new file mode 100644 index 00000000..28f171db --- /dev/null +++ b/internal/mock_auth/validate.go @@ -0,0 +1,85 @@ +// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved. +// SPDX-License-Identifier: Apache-2.0 +package mock_auth + +import ( + "encoding/json" + "net/http" + "strings" + "time" + + "github.com/twitchdev/twitch-cli/internal/database" +) + +type ValidateTokenEndpoint struct{} + +type ValidateTokenEndpointResponse struct { + ClientID string `json:"client_id"` + UserID string `json:"user_id,omitempty"` + UserLogin string `json:"login,omitempty"` + ExpiresIn int `json:"expires_in"` + Scopes []string `json:"scopes"` +} + +func (e ValidateTokenEndpoint) Path() string { return "/validate" } + +func (e ValidateTokenEndpoint) ServeHTTP(w http.ResponseWriter, r *http.Request) { + db = r.Context().Value("db").(database.CLIDatabase) + + if r.Method != http.MethodGet { + w.WriteHeader(http.StatusMethodNotAllowed) + return + } + + tokenHeader := r.Header.Get("Authorization") + if tokenHeader == "" || len(tokenHeader) < 7 { + w.WriteHeader(http.StatusUnauthorized) + return + } + + token := "" + // handle prefixes + h := strings.ToLower(tokenHeader) + if strings.HasPrefix(h, "oauth ") { + token = tokenHeader[6:] + } else if strings.HasPrefix(h, "bearer ") { + token = tokenHeader[7:] + } + println(token) + auth, err := db.NewQuery(r, 100).GetAuthorizationByToken(token) + if err != nil || auth.ID == 0 { + w.WriteHeader(http.StatusUnauthorized) + return + } + + expiresAt, _ := time.Parse(time.RFC3339, auth.ExpiresAt) + + diff := expiresAt.Sub(time.Now()) + + scopes := []string{} + for _, s := range strings.Split(auth.Scopes, " ") { + if s != "" { + scopes = append(scopes, s) + } + } + resp := ValidateTokenEndpointResponse{ + ClientID: auth.ClientID, + UserID: auth.UserID, + Scopes: scopes, + ExpiresIn: int(diff.Seconds()), + } + if auth.UserID != "" { + user, err := db.NewQuery(r, 100).GetUser(database.User{ID: auth.UserID}) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } + resp.UserLogin = user.UserLogin + } + + err = json.NewEncoder(w).Encode(resp) + if err != nil { + w.WriteHeader(http.StatusInternalServerError) + return + } +}