diff --git a/cmd/pr-size-labeler-action/main.go b/cmd/pr-size-labeler-action/main.go index c9b8d32..6271731 100644 --- a/cmd/pr-size-labeler-action/main.go +++ b/cmd/pr-size-labeler-action/main.go @@ -41,7 +41,6 @@ func (EnvArgs) Version() string { // Constants for default configuration and event names. const ( DefaultConfigPath = ".github/pull-request-size.yml" - EventPullRequest = "pull_request" ParamNameFiles = "files" ParamNameDiff = "diff" ) @@ -116,7 +115,7 @@ func main() { var args EnvArgs arg.MustParse(&args) - if !isValidEvent(args.EventName) || !isValidRepoFormat(args.RepoName) { + if !isValidGitHubEventType(args.EventName) || !isValidRepoFormat(args.RepoName) { return } @@ -138,13 +137,19 @@ func main() { prProcessor.ProcessPullRequest() } -// isValidEvent checks if the event name is a valid pull request event. -func isValidEvent(eventName string) bool { - if eventName != EventPullRequest { - fmt.Println("Event is not a pull request, doing nothing") - return false +// isValidGitHubEventType checks if the event name is a valid pull request event. +func isValidGitHubEventType(eventName string) bool { + allowedEvents := map[string]bool{ + "pull_request": true, + "pull_request_target": true, } - return true + + if allowedEvents[strings.ToLower(eventName)] { + return true + } + + fmt.Println("Event is not a valid pull request event, doing nothing") + return false } // isValidRepoFormat checks if the repository name follows the 'owner/repository' format. diff --git a/cmd/pr-size-labeler-action/main_test.go b/cmd/pr-size-labeler-action/main_test.go index 53a55f7..4882ace 100644 --- a/cmd/pr-size-labeler-action/main_test.go +++ b/cmd/pr-size-labeler-action/main_test.go @@ -75,6 +75,33 @@ func TestCalculateSizeAndDiff(t *testing.T) { } } +func TestIsValidGitHubEventType(t *testing.T) { + tests := []struct { + name string + eventName string + want bool + }{ + {"Valid Event pull_request", "pull_request", true}, + {"Valid Event pull_request_target", "pull_request_target", true}, + {"Invalid Event empty", "", false}, + {"Invalid Event random string", "random_event", false}, + {"Invalid Event issue", "issue", false}, + {"Invalid Event commit", "commit", false}, + {"Invalid Event push", "push", false}, + {"Invalid Event merge", "merge", false}, + {"Invalid Event null", "null", false}, + {"Invalid Event pull_request_closed", "pull_request_closed", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + if got := isValidGitHubEventType(tt.eventName); got != tt.want { + t.Errorf("isValidGitHubEventType(%v) = %v, want %v", tt.eventName, got, tt.want) + } + }) + } +} + func TestGetSize(t *testing.T) { // Define configuration entries for clarity xsConfig := ConfigEntry{"xs", 10, 1, []string{"size/xs"}}