From 0c955c5b54fd36afa9fc9183c5febea6d1a226c2 Mon Sep 17 00:00:00 2001 From: Forrest <30576607+fspmarshall@users.noreply.github.com> Date: Mon, 24 Apr 2023 10:03:18 -0700 Subject: [PATCH] fix github url formatting (#25089) --- lib/auth/github.go | 10 ++++++++-- lib/auth/github_test.go | 28 ++++++++++++++++++++++++++++ 2 files changed, 36 insertions(+), 2 deletions(-) diff --git a/lib/auth/github.go b/lib/auth/github.go index 7ae335989faec..ea9a1539705b6 100644 --- a/lib/auth/github.go +++ b/lib/auth/github.go @@ -24,6 +24,7 @@ import ( "io" "net/http" "net/url" + "strings" "time" "github.com/coreos/go-oidc/oauth2" @@ -950,8 +951,8 @@ func (c *githubAPIClient) getTeams() ([]teamResponse, error) { } // get makes a GET request to the provided URL using the client's token for auth -func (c *githubAPIClient) get(url string) ([]byte, string, error) { - request, err := http.NewRequest("GET", fmt.Sprintf("https://api.%s/%s", c.endpointHostname, url), nil) +func (c *githubAPIClient) get(page string) ([]byte, string, error) { + request, err := http.NewRequest("GET", formatGithubURL(c.endpointHostname, page), nil) if err != nil { return nil, "", trace.Wrap(err) } @@ -977,6 +978,11 @@ func (c *githubAPIClient) get(url string) ([]byte, string, error) { return bytes, wls.NextPage, nil } +// formatGithubURL is a helper for formatting github api request URLs. +func formatGithubURL(host string, path string) string { + return fmt.Sprintf("https://%s/%s", host, strings.TrimPrefix(path, "/")) +} + const ( // GithubAuthPath is the GitHub authorization endpoint GithubAuthPath = "login/oauth/authorize" diff --git a/lib/auth/github_test.go b/lib/auth/github_test.go index e311c3df638b3..7f3445d894d65 100644 --- a/lib/auth/github_test.go +++ b/lib/auth/github_test.go @@ -532,3 +532,31 @@ func TestCheckGithubOrgSSOSupport(t *testing.T) { }) } } + +func TestGithubURLFormat(t *testing.T) { + tts := []struct { + host string + path string + expect string + }{ + { + host: "example.com", + path: "foo/bar", + expect: "https://example.com/foo/bar", + }, + { + host: "example.com", + path: "/foo/bar?spam=eggs", + expect: "https://example.com/foo/bar?spam=eggs", + }, + { + host: "example.com", + path: "/foo/bar", + expect: "https://example.com/foo/bar", + }, + } + + for _, tt := range tts { + require.Equal(t, tt.expect, formatGithubURL(tt.host, tt.path)) + } +}