diff --git a/backend/mock/service/githubmock/githubmock.go b/backend/mock/service/githubmock/githubmock.go index 8e6d7e4d96..eed6eefbc0 100644 --- a/backend/mock/service/githubmock/githubmock.go +++ b/backend/mock/service/githubmock/githubmock.go @@ -2,6 +2,8 @@ package githubmock import ( "context" + "fmt" + "strconv" "github.com/golang/protobuf/ptypes/any" githubv3 "github.com/google/go-github/v37/github" @@ -63,6 +65,17 @@ func (s svc) ListOrganizations(ctx context.Context, user string) ([]*githubv3.Or }, nil } +func (s svc) ListPullRequestsWithCommit(ctx context.Context, ref *github.RemoteRef, sha string, opts *githubv3.PullRequestListOptions) ([]*github.PullRequestInfo, error) { + prNumber := 12345 + return []*github.PullRequestInfo{ + { + Number: prNumber, + HTMLURL: fmt.Sprintf("https://github.com/%s/%s/pull/%s", ref.RepoOwner, ref.RepoName, strconv.Itoa(prNumber)), + BranchName: "my-branch", + }, + }, nil +} + func (s svc) GetOrgMembership(ctx context.Context, user, org string) (*githubv3.Membership, error) { role := "member" return &githubv3.Membership{Role: &role}, nil diff --git a/backend/service/github/github.go b/backend/service/github/github.go index 71e9aef8bc..48be2222fe 100644 --- a/backend/service/github/github.go +++ b/backend/service/github/github.go @@ -103,6 +103,7 @@ type Client interface { GetRepository(ctx context.Context, ref *RemoteRef) (*Repository, error) GetOrganization(ctx context.Context, organization string) (*githubv3.Organization, error) ListOrganizations(ctx context.Context, user string) ([]*githubv3.Organization, error) + ListPullRequestsWithCommit(ctx context.Context, ref *RemoteRef, sha string, opts *githubv3.PullRequestListOptions) ([]*PullRequestInfo, error) GetOrgMembership(ctx context.Context, user, org string) (*githubv3.Membership, error) GetUser(ctx context.Context, username string) (*githubv3.User, error) } @@ -117,8 +118,9 @@ func (s *svc) CreateIssueComment(ctx context.Context, ref *RemoteRef, number int } type PullRequestInfo struct { - Number int - HTMLURL string + Number int + HTMLURL string + BranchName string } type svc struct { @@ -236,6 +238,24 @@ func (s *svc) CreatePullRequest(ctx context.Context, ref *RemoteRef, base, title }, nil } +func (s *svc) ListPullRequestsWithCommit(ctx context.Context, ref *RemoteRef, sha string, opts *githubv3.PullRequestListOptions) ([]*PullRequestInfo, error) { + respPRs, _, err := s.rest.PullRequests.ListPullRequestsWithCommit(ctx, ref.RepoOwner, ref.RepoName, sha, opts) + if err != nil { + return nil, err + } + + prInfos := make([]*PullRequestInfo, len(respPRs)) + for i, pr := range respPRs { + prInfos[i] = &PullRequestInfo{ + Number: pr.GetNumber(), + HTMLURL: pr.GetHTMLURL(), + BranchName: pr.GetHead().GetRef(), + } + } + + return prInfos, nil +} + type CreateBranchRequest struct { // The base for the new branch. Ref *RemoteRef diff --git a/backend/service/github/github_test.go b/backend/service/github/github_test.go index f74bf922f3..a0d0893208 100644 --- a/backend/service/github/github_test.go +++ b/backend/service/github/github_test.go @@ -6,6 +6,7 @@ import ( "fmt" "io/ioutil" "net/http" + "strconv" "testing" "time" @@ -23,6 +24,10 @@ var ( timestamp = time.Unix(1569010072, 0) ) +func intPtr(i int) *int { + return &i +} + type getfileMock struct { v4client @@ -672,3 +677,115 @@ func TestGetRepository(t *testing.T) { }) } } + +type mockPullRequests struct { + generalError bool + + actualNumber int + actualHTMLURL string + actualBranchName string +} + +// Dummy mock of Create API so mockPullRequests implements v3pullrequests +func (m *mockPullRequests) Create(ctx context.Context, owner string, repo string, pull *githubv3.NewPullRequest) (*githubv3.PullRequest, *githubv3.Response, error) { + return &githubv3.PullRequest{}, &githubv3.Response{}, nil +} + +// Mock of ListPullRequestsWithCommit API +func (m *mockPullRequests) ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *githubv3.PullRequestListOptions) ([]*githubv3.PullRequest, *githubv3.Response, error) { + if m.generalError { + return nil, nil, errors.New(problem) + } + + m.actualNumber = 1347 + m.actualHTMLURL = fmt.Sprintf("https://github.com/%s/%s/pull/%s", owner, repo, strconv.Itoa(m.actualNumber)) + m.actualBranchName = "my-branch" + + return []*githubv3.PullRequest{ + { + Number: intPtr(m.actualNumber), + State: strPtr(opts.State), + HTMLURL: strPtr(m.actualHTMLURL), + Head: &githubv3.PullRequestBranch{ + Ref: strPtr(m.actualBranchName), + SHA: strPtr(sha), + Repo: &githubv3.Repository{ + Name: strPtr(repo), + }, + User: &githubv3.User{ + Login: strPtr("octocat"), + }, + }, + }, + }, nil, nil +} + +var listPullRequestsWithCommitTests = []struct { + name string + errorText string + mockPullReq *mockPullRequests + repoOwner string + repoName string + ref string + sha string + opts *githubv3.PullRequestListOptions +}{ + { + name: "happy path", + mockPullReq: &mockPullRequests{}, + repoOwner: "my-org", + repoName: "my-repo", + ref: "my-branch", + sha: "asdf12345", + opts: &githubv3.PullRequestListOptions{ + // Possible values for State: "open", "closed", "all". Default is "open", manually setting it to "all". + State: "all", + }, + }, + { + name: "v3 client error", + mockPullReq: &mockPullRequests{generalError: true}, + errorText: "we've had a problem", + repoOwner: "my-org", + repoName: "my-repo", + ref: "my-branch", + sha: "asdf12345", + opts: &githubv3.PullRequestListOptions{ + State: "all", + }, + }, +} + +func TestListPullRequestsWithCommit(t *testing.T) { + for _, tt := range listPullRequestsWithCommitTests { + tt := tt + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + s := &svc{rest: v3client{ + PullRequests: tt.mockPullReq, + }} + + resp, err := s.ListPullRequestsWithCommit( + context.Background(), + &RemoteRef{ + RepoOwner: tt.repoOwner, + RepoName: tt.repoName, + Ref: tt.ref, + }, + tt.sha, + tt.opts) + + if tt.errorText != "" { + assert.Error(t, err) + assert.Contains(t, err.Error(), tt.errorText) + } else { + assert.NoError(t, err) + assert.Equal(t, 1, len(resp)) + assert.Equal(t, tt.mockPullReq.actualNumber, resp[0].Number) + assert.Equal(t, tt.mockPullReq.actualHTMLURL, resp[0].HTMLURL) + assert.Equal(t, tt.mockPullReq.actualBranchName, resp[0].BranchName) + } + }) + } +} diff --git a/backend/service/github/iface.go b/backend/service/github/iface.go index a87594ff81..41cc092c44 100644 --- a/backend/service/github/iface.go +++ b/backend/service/github/iface.go @@ -33,6 +33,8 @@ type v3repositories interface { type v3pullrequests interface { // Create a new pull request on the specified repository. Create(ctx context.Context, owner string, repo string, pull *githubv3.NewPullRequest) (*githubv3.PullRequest, *githubv3.Response, error) + // ListPullRequestsWithCommit returns pull requests associated with a commit SHA. + ListPullRequestsWithCommit(ctx context.Context, owner, repo, sha string, opts *githubv3.PullRequestListOptions) ([]*githubv3.PullRequest, *githubv3.Response, error) } type v4client interface {