diff --git a/.github/scripts/backport-pending.py b/.github/scripts/backport-pending.py deleted file mode 100644 index f2e3e71df..000000000 --- a/.github/scripts/backport-pending.py +++ /dev/null @@ -1,128 +0,0 @@ -import json -import os -import re -import sys -import urllib.error -import urllib.request -from dataclasses import dataclass -from typing import List - -VERSION_LABEL_RE = re.compile(r"^v\d{1,2}$|(^v\d{1,2}\.\d{1,2}$)") -PENDING_LABEL = "backport pending" -PENDING_LABEL_COLOR = "fff2bf" - - -@dataclass -class PRInfo: - number: int - labels: List[str] - - -def load_event() -> dict: - path = os.environ.get("GITHUB_EVENT_PATH") - if not path or not os.path.exists(path): - print("::warning::GITHUB_EVENT_PATH not set or file missing; nothing to do", file=sys.stderr) - return {} - with open(path, "r", encoding="utf-8") as f: - return json.load(f) - - -def extract_pr(event: dict) -> PRInfo | None: - pr = event.get("pull_request") - if not pr: - return None - labels = [lbl.get("name", "") for lbl in pr.get("labels", [])] - return PRInfo(number=pr["number"], labels=labels) - - -def needs_pending_label(info: PRInfo) -> bool: - has_version_label = any(VERSION_LABEL_RE.match(l) for l in info.labels) - has_pending = PENDING_LABEL in info.labels - return not (has_version_label and has_pending) - - -def add_label(pr_number: int, label: str) -> None: - repo = os.environ.get("GITHUB_REPOSITORY") - token = os.environ.get("BACKPORT_TOKEN") - if not repo or not token: - print("::error::Missing GITHUB_REPOSITORY or BACKPORT_TOKEN", file=sys.stderr) - sys.exit(1) - owner, repo_name = repo.split("/", 1) - # First ensure the label exists (create or update color/description) - ensure_label(owner, repo_name, token) - url = f"https://api.github.com/repos/{owner}/{repo_name}/issues/{pr_number}/labels" - body = json.dumps({"labels": [label]}).encode() - # POST adds label(s) keeping old ones - req = urllib.request.Request(url, data=body, method="POST") - 
req.add_header("Authorization", f"Bearer {token}") - req.add_header("Accept", "application/vnd.github+json") - try: - with urllib.request.urlopen(req) as resp: - if resp.status not in (200, 201): - print(f"::error::Failed to add label: HTTP {resp.status}", file=sys.stderr) - sys.exit(1) - print(f"Added label '{label}' to PR #{pr_number}") - except urllib.error.HTTPError as e: - print(f"::error::HTTP error adding label: {e.code} {e.reason}", file=sys.stderr) - sys.exit(1) - except Exception as e: - print(f"::error::Unexpected error adding label: {e}", file=sys.stderr) - sys.exit(1) - - -def ensure_label(owner: str, repo_name: str, token: str) -> None: - """Create the Backport Pending label if it does not already exist.""" - label_api = f"https://api.github.com/repos/{owner}/{repo_name}/labels/{PENDING_LABEL.replace(' ', '%20')}" - get_req = urllib.request.Request(label_api, method="GET") - get_req.add_header("Authorization", f"Bearer {token}") - get_req.add_header("Accept", "application/vnd.github+json") - try: - with urllib.request.urlopen(get_req) as resp: - if resp.status == 200: - return - except urllib.error.HTTPError as e: - print(f"::warning::Failed to check label existence ({e.code})") - return - create_api = f"https://api.github.com/repos/{owner}/{repo_name}/labels" - body = json.dumps({"name": PENDING_LABEL, "color": PENDING_LABEL_COLOR}).encode() - req = urllib.request.Request(create_api, data=body, method="POST") - req.add_header("Authorization", f"Bearer {token}") - req.add_header("Accept", "application/vnd.github+json") - try: - with urllib.request.urlopen(req) as resp: - if resp.status not in (200, 201): - print(f"::warning::Failed to create label (status {resp.status})") - except Exception as e: - print(f"::warning::Error creating label: {e}") - - -""" -Label a PR with 'Backport pending' if it has no version label. 
- -Expected environment: - GITHUB_EVENT_PATH: Path to the event JSON (GitHub sets this automatically) - GITHUB_REPOSITORY: owner/repo - BACKPORT_TOKEN: token with repo:issues scope (use BACKPORT_TOKEN or a PAT) - -This script is idempotent: if the PR already has a version label (vX.Y) or already -has the 'Backport Pending' label, it exits without error. -""" - - -def main() -> int: - event = load_event() - if not event: - return 0 - info = extract_pr(event) - if not info: - print("No pull_request object in event; skipping") - return 0 - if needs_pending_label(info): - add_label(info.number, PENDING_LABEL) - else: - print("No label needed (either PR has version label or already pending)") - return 0 - - -if __name__ == "__main__": - raise SystemExit(main()) diff --git a/.github/workflows/backport.yml b/.github/workflows/backport.action.yml similarity index 100% rename from .github/workflows/backport.yml rename to .github/workflows/backport.action.yml diff --git a/.github/workflows/backport.reminder.yml b/.github/workflows/backport.reminder.yml new file mode 100644 index 000000000..cf1634480 --- /dev/null +++ b/.github/workflows/backport.reminder.yml @@ -0,0 +1,42 @@ +name: Backport reminder + +on: + pull_request_target: + branches: [master] + types: [closed] + schedule: + - cron: '0 6 * * *' # Every day at 06:00 UTC + workflow_dispatch: + inputs: + lookback_days: + description: 'How many days back to search merged PRs' + required: false + default: '7' + pending_label_age_days: + description: 'Minimum age in days before reminding' + required: false + default: '14' + +env: + BACKPORT_TOKEN: ${{ secrets.BACKPORT_TOKEN }} + +jobs: + reminder: + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: write + steps: + - name: Checkout + uses: actions/checkout@v4 + - name: Add Backport pending label (single PR) + if: github.event_name == 'pull_request_target' + run: | + python github_ci_tools/scripts/backport.py label --pr-mode + python 
github_ci_tools/scripts/backport.py remind --pr-mode --pending-reminder-age-days ${{ github.event.inputs.pending_label_age_days }} + - name: Add Backport pending label (bulk) + if: github.event_name != 'pull_request_target' + run: | + python github_ci_tools/scripts/backport.py label --lookback-days ${{ github.event.inputs.lookback_days }} + python github_ci_tools/scripts/backport.py remind --lookback-days ${{ github.event.inputs.lookback_days }} --pending-reminder-age-days ${{ github.event.inputs.pending_label_age_days }} \ No newline at end of file diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index fb240820d..bfc8931a9 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -76,12 +76,12 @@ jobs: steps: - uses: actions/checkout@v4 - name: Parse repo and create filters.yml - run: python3 .github/scripts/track-filter.py + run: python3 github_ci_tools/scripts/track-filter.py - uses: dorny/paths-filter@de90cc6fb38fc0963ad72b210f1f284cd68cea36 #v3.0.2 id: changes with: token: ${{ secrets.GITHUB_TOKEN }} - filters: .github/filters.yml + filters: github_ci_tools/filters.yml - name: Collect changed tracks and calculate --track-filter argument id: track-filter run: | diff --git a/.gitignore b/.gitignore index 73633544c..7d9b795e6 100644 --- a/.gitignore +++ b/.gitignore @@ -99,6 +99,9 @@ target/ #Pickles *.pk +# direnv +.envrc + # pyenv .python-version diff --git a/github_ci_tools/__init__.py b/github_ci_tools/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/github_ci_tools/scripts/__init__.py b/github_ci_tools/scripts/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/github_ci_tools/scripts/backport.py b/github_ci_tools/scripts/backport.py new file mode 100755 index 000000000..f30480052 --- /dev/null +++ b/github_ci_tools/scripts/backport.py @@ -0,0 +1,521 @@ +#!/usr/bin/env python3 + +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. 
See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Backport CLI + +- Apply 'backport pending' label to merged PRs that require backport. +- Post reminder comments on such PRs that have a 'backport pending' label +but a version label (e.g. vX.Y) has not been added yet. +- Omits PRs labeled 'backport'. + +Usage: backport.py [options] [flags] + +Options: + --repo owner/repo + --pr-mode Single PR mode (use event payload); Handle PR through GITHUB_EVENT_PATH. + -v, --verbose Increase verbosity (can be repeated: -vv) + -q, --quiet Decrease verbosity (can be repeated: -qq) + --dry-run Simulate actions without modifying GitHub state + +Commands: + label Add 'backport pending' label to merged PRs lacking version/backport labels + remind Post reminders on merged PRs still pending backport + +Flags: + --lookback-days N Days to scan in bulk + --pending-reminder-age-days M Days between reminders + --remove Remove 'backport pending' label + +Quick usage: + backport.py label --pr-mode + backport.py --repo owner/name label --lookback-days 7 + backport.py --repo owner/name remind --lookback-days 30 --pending-reminder-age-days 14 + backport.py --repo owner/name --dry-run -vv label --lookback-days 30 + +Logic: + Add label when: no version label (regex vX(.Y)), no pending or 'backport' label. 
+ Remind when: pending label present AND (no previous reminder OR last reminder older than M days). + Marker: + +Exit codes: 0 success / 1 error. +""" + +import argparse +import datetime as dt +import itertools +import json +import logging +import os +import re +import sys +import urllib.error +import urllib.request +from collections.abc import Iterable +from dataclasses import dataclass, field +from typing import Any +from urllib.parse import urlencode + +LOG = logging.getLogger(__name__) + +ISO_FORMAT = "%Y-%m-%dT%H:%M:%SZ" +VERSION_LABEL_RE = re.compile(r"^v\d{1,2}(?:\.\d{1,2})?$") +BACKPORT_LABEL = "backport" +PENDING_LABEL = "backport pending" +PENDING_LABEL_COLOR = "fff2bf" +COULD_NOT_CREATE_LABEL_WARNING = "Could not create label" +GITHUB_API = "https://api.github.com" +COMMENT_MARKER_BASE = "" # static for detection +REMINDER_BODY = ( + "A backport is pending for this PR. Please add all required `vX.Y` version labels.\n\n" + " - If it is intended for the current Elasticsearch release version, apply the corresponding version label.\n" + " - If it also supports past released versions, add those labels too.\n" + " - If it only targets a future version, wait until that version label exists and then add it.\n" + " (Each rally-tracks version label is created during the feature freeze of a new Elasticsearch branch).\n\n" + "Backporting entails: \n" + " 1. Ensure the correct version labels exist in this PR.\n" + " 2. Ensure backport PRs have `backport` label and are passing tests.\n" + " 3. Merge backport PRs (you can approve yourself and enable auto-merge).\n" + " 4. Remove `backport pending` label from this PR once all backport PRs are merged.\n\n" + "Thank you!" 
+) + + +@dataclass +class BackportConfig: + token: str | None = None + repo: str | None = None + dry_run: bool = False + log_level: int = logging.INFO + command: str | None = None + verbose: int = 0 + quiet: int = 0 + + +CONFIG = BackportConfig( + token=os.environ.get("BACKPORT_TOKEN"), + repo=os.environ.get("GITHUB_REPOSITORY"), +) + + +# ----------------------------- GH Helpers ----------------------------- +def gh_request(method: str = "GET", path: str = "", body: dict[str, Any] | None = None, params: dict[str, str] | None = None) -> Any: + if params: + path = f"{path}?{urlencode(params)}" + url = f"{GITHUB_API}/{path}" + data = None + if body is not None: + data = json.dumps(body).encode() + # In dry-run, skip mutating requests (anything not GET) and just log. + if is_dry_run(): + LOG.debug(f"Would {method} {url} body={json.dumps(body)}") + if method.upper() != "GET": + return {} + req = urllib.request.Request(url, data=data, method=method) + req.add_header("Authorization", f"Bearer {CONFIG.token}") + req.add_header("Accept", "application/vnd.github+json") + try: + with urllib.request.urlopen(req) as resp: + charset = resp.headers.get_content_charset() or "utf-8" + txt = resp.read().decode(charset) + LOG.debug(f"Response {resp.status} {method}") + if resp.status >= 300: + raise RuntimeError(f"HTTP {resp.status}: {txt}") + return json.loads(txt) if txt.strip() else {} + except urllib.error.HTTPError as e: + err = e.read().decode() + raise RuntimeError(f"HTTP {e.code} {e.reason} {err}") from e + + +@dataclass +class PRInfo: + number: int = -1 + labels: list[str] = field(default_factory=list) + + @classmethod + def from_dict(cls, pr: dict[str, Any]) -> "PRInfo": + number = int(pr.get("number") or pr.get("url", "").rstrip("/").strip().rsplit("/", 1) or -1) + labels = [lbl.get("name", "") for lbl in pr.get("labels", [])] + if number == -1 and not labels: + raise ValueError("...") + return cls(number, labels) + + +# ----------------------------- PR Extraction (single 
or bulk) ----------------------------- +def load_event() -> dict: + """Load the GitHub event payload from GITHUB_EVENT_PATH for single PR mode. + + Returns an empty dict if the path is missing to allow callers to decide on fallback behavior. + """ + path = os.environ.get("GITHUB_EVENT_PATH", "").strip() + if not path: + raise FileNotFoundError("GITHUB_EVENT_PATH environment variable is empty") + if not os.path.isfile(path): + raise FileNotFoundError(f"File not found: {path}") + with open(path, "r", encoding="utf-8") as f: + data = json.load(f) + if not isinstance(data, dict): + raise TypeError(f"Event data is a {type(data)}, want a dict.") + return data + + +def list_prs(q_filter: str, since: dt.datetime) -> Iterable[dict[str, Any]]: + """Query the GH API with a filter to iterate over PRs updated after a given timestamp.""" + q_date = since.strftime("%Y-%m-%d") + q = f"{q_filter} updated:>={q_date}" + LOG.debug(f"Fetch PRs with filter '{q}'") + params = {"q": f"{q}", "per_page": "100"} + for page in itertools.count(1): + params["page"] = str(page) + results = gh_request(path="search/issues", params=params) + items = results.get("items", []) + yield from items + if len(items) < 100: + break + + +# ----------------------------- Label Logic ----------------------------- +def add_repository_label(repository: str | None, name: str, color: str): + if repository is None: + raise RuntimeError("Cannot add label: repository is None") + ( + LOG.info(f"Would create label '{name}' with color '{color}' in repo '{repository}'") + if is_dry_run() + else LOG.info(f"Creating label '{name}' with color '{color}' in repo '{repository}'") + ) + gh_request(method="POST", path=f"repos/{repository}/labels", body={"name": name, "color": color}) + + +def repo_needs_pending_label(repo_labels: list[str]) -> bool: + LOG.debug(f"{PENDING_LABEL} in repo labels: {repo_labels} -> {PENDING_LABEL in repo_labels}") + return PENDING_LABEL not in repo_labels + + +def ensure_backport_pending_label() -> 
None: + """If the exact PENDING_LABEL string does not appear at least once, we create it.""" + try: + existing = gh_request(path=f"repos/{CONFIG.repo}/labels", params={"per_page": "100"}) + except Exception as e: + existing = [] + names = [lbl.get("name", "") for lbl in existing] + if not repo_needs_pending_label(names): + return + try: + add_repository_label(repository=CONFIG.repo, name=PENDING_LABEL, color=PENDING_LABEL_COLOR) + except Exception as e: + LOG.warning(f"{COULD_NOT_CREATE_LABEL_WARNING}: {e}") + + +def pr_needs_pending_label(info: PRInfo) -> bool: + has_version_label = any(VERSION_LABEL_RE.match(label) for label in info.labels) + return PENDING_LABEL not in info.labels and BACKPORT_LABEL not in info.labels and not has_version_label + + +def add_pull_request_label(pr_number: int, label: str) -> None: + LOG.info(f"Would add label '{label}' to PR #{pr_number}") if is_dry_run() else LOG.info(f"Adding label '{label}' to PR #{pr_number}") + gh_request(method="POST", path=f"repos/{CONFIG.repo}/issues/{pr_number}/labels", body={"labels": [label]}) + + +def remove_pull_request_label(pr_number: int, label: str) -> None: + ( + LOG.info(f"Would remove label '{label}' from PR #{pr_number}") + if is_dry_run() + else LOG.info(f"Removing label '{label}' from PR #{pr_number}") + ) + gh_request(method="DELETE", path=f"repos/{CONFIG.repo}/issues/{pr_number}/labels", body={"labels": [label]}) + + +def run_label(prefetched_prs: list[dict[str, Any]], remove: bool) -> int: + """Apply label logic to prefetched merged PRs (single for pull_request_target or bulk).""" + if not prefetched_prs: + raise RuntimeError("No PRs prefetched for labeling") + # Ensure repository has pending label definition before any per-PR action. 
+ try: + ensure_backport_pending_label() + except Exception as e: + raise RuntimeError(f"Cannot ensure that backport pending label exists in repo: {e}") + errors = 0 + for pr in prefetched_prs: + try: + if not pr: + continue + info = PRInfo.from_dict(pr) + if remove: + remove_pull_request_label(info.number, PENDING_LABEL) + elif pr_needs_pending_label(info): + add_pull_request_label(info.number, PENDING_LABEL) + else: + LOG.debug(f"PR #{info.number}: No label action needed") + except Exception as e: + LOG.error(f"Label error for PR #{pr.get('number','unknown')}: {e}") + errors += 1 + return errors + + +# ----------------------------- Reminder Logic ----------------------------- +def get_issue_comments(number: int) -> list[dict[str, Any]]: + comments: list[dict[str, Any]] = [] + page = 1 + repo = CONFIG.repo + while True: + data = gh_request(path=f"repos/{repo}/issues/{number}/comments", params={"per_page": "100", "page": str(page)}) + if not data: + break + comments.extend(data) + # We are using a page size of 100. If we get less, we are done. 
+ if len(data) < 100: + break + page += 1 + return comments + + +def add_comment(number: int, body: str) -> None: + if is_dry_run(): + LOG.info(f"Would add comment to PR #{number}:\n{body}") + return + gh_request(method="POST", path=f"repos/{CONFIG.repo}/issues/{number}/comments", body={"body": body}) + + +def last_reminder_time(comments: list[dict[str, Any]], marker: str) -> dt.datetime | None: + def comment_ts(c: dict[str, Any]) -> dt.datetime: + raw_timestamp = c.get("created_at") or c.get("updated_at") + if not raw_timestamp: + raise RuntimeError("Comment missing timestamp fields") + return dt.datetime.strptime(raw_timestamp, ISO_FORMAT).replace(tzinfo=dt.timezone.utc) + + for c in sorted(comments, key=comment_ts, reverse=True): + body = c.get("body") or "" + if marker in body: + return comment_ts(c) + return None + + +def pr_needs_reminder(info: PRInfo, threshold: dt.datetime) -> bool: + if not any(label == PENDING_LABEL for label in info.labels): + return False + comments = get_issue_comments(info.number) + prev_time = last_reminder_time(comments, COMMENT_MARKER_BASE) + if prev_time is None: + return True + return prev_time < threshold + + +def delete_reminders(info: PRInfo) -> None: + comments = get_issue_comments(info.number) + repo = CONFIG.repo + for c in comments: + body = c.get("body") or "" + if COMMENT_MARKER_BASE in body: + comment_id = c.get("id") + if comment_id is None: + LOG.warning(f"Cannot delete comment on PR #{info.number}: missing comment ID") + continue + if is_dry_run(): + LOG.info(f"Would delete comment ID {comment_id} on PR #{info.number}") + continue + gh_request(method="DELETE", path=f"repos/{repo}/issues/comments/{comment_id}") + LOG.info(f"Deleted comment ID {comment_id} on PR #{info.number}") + + +def run_remind(prefetched_prs: list[dict[str, Any]], pending_reminder_age_days: int, lookback_days: int) -> int: + """Post reminders using prefetched merged PR list.""" + if not prefetched_prs: + raise RuntimeError("No PRs prefetched for 
reminding") + now = dt.datetime.now(dt.timezone.utc) + threshold = now - dt.timedelta(days=pending_reminder_age_days) + errors = 0 + for pr in prefetched_prs: + try: + if not pr: + continue + info = PRInfo.from_dict(pr) + if pr_needs_reminder(info, threshold): + author = pr.get("user", {}).get("login", "PR author") + delete_reminders(info) + add_comment(info.number, f"{COMMENT_MARKER_BASE}\n@{author}\n{REMINDER_BODY}") + LOG.info(f"PR #{info.number}: initial reminder posted") + else: + LOG.info(f"PR #{info.number}: cooling period not elapsed)") + except Exception as ex: + LOG.error(f"Remind error for PR #{pr.get('number', '?')}: {ex}") + errors += 1 + continue + return errors + + +# ----------------------------- CLI ----------------------------- +def is_dry_run() -> bool: + return CONFIG.dry_run + + +def require_mandatory_vars() -> None: + """Validate critical environment / CLI inputs using CONFIG.""" + if not CONFIG.token: + raise RuntimeError("Missing BACKPORT_TOKEN from environment.") + repo = CONFIG.repo + if not repo or not re.match(r"^[^/]+/[^/]+$", str(repo)): + raise RuntimeError("Missing or invalid GITHUB_REPOSITORY. Either set it or pass --repo (owner/repo)") + + +def configure(args: argparse.Namespace) -> None: + """Populate CONFIG, initialize logging, and validate required inputs. + + This centralizes setup so other entry points (tests, future subcommands) + can reuse consistent initialization semantics. 
+ """ + CONFIG.dry_run = args.dry_run + CONFIG.verbose = args.verbose + CONFIG.quiet = args.quiet + CONFIG.log_level = (CONFIG.quiet - CONFIG.verbose) * (logging.INFO - logging.DEBUG) + logging.INFO + CONFIG.command = args.command + CONFIG.token = os.environ.get("BACKPORT_TOKEN") + CONFIG.repo = args.repo if args.repo is not None else os.environ.get("GITHUB_REPOSITORY") + logging.basicConfig(level=CONFIG.log_level, format="%(asctime)s %(levelname)s %(name)s %(message)s") + require_mandatory_vars() + + +def prefetch_prs(pr_mode: bool, lookback_days: int) -> list[dict[str, Any]]: + if pr_mode: + event = load_event() + if event: + pr_data = event.get("pull_request") + else: + raise RuntimeError("Failed to load event data") + if not pr_data: + raise RuntimeError(f"No pull_request data in event: {event}") + # Ensure PR is merged. + merged_flag = pr_data.get("merged") + merged_at = pr_data.get("merged_at") + if not merged_flag and not merged_at: + raise RuntimeError(f"PR #{pr_data.get('number','?')} not merged yet; skipping.") + try: + merged_dt = dt.datetime.strptime(merged_at, ISO_FORMAT).replace(tzinfo=dt.timezone.utc) + except ValueError as e: + raise RuntimeError(f"Invalid merged_at format: {merged_at}") from e + now = dt.datetime.now(dt.timezone.utc) + age_days = (now - merged_dt).days + if age_days >= lookback_days + 1: + LOG.info( + f"PR #{pr_data.get('number','?')} merged_at {merged_at} age={age_days}d " + f"exceeds lookback_days={lookback_days}; filtering out." + ) + return [] + return [pr_data] + now = dt.datetime.now(dt.timezone.utc) + since = now - dt.timedelta(days=lookback_days) + repo = CONFIG.repo + # Note that we rely on is:merged to filter out unmerged PRs. 
+ return list(list_prs(f"repo:{repo} is:pr is:merged", since)) + + +def parse_args() -> argparse.Namespace: + try: + parser = argparse.ArgumentParser( + description="Backport utilities", + epilog="""\nExamples:\n backport.py label --pr-mode\n backport.py label --lookback-days 7\n backport.py remind --lookback-days 30 --pending-reminder-age-days 14\n backport.py --dry-run -vv label --lookback-days 30\n\nSingle PR mode (--pr-mode) reads the pull_request payload from GITHUB_EVENT_PATH.\nBulk mode searches merged PRs updated within --lookback-days.\n""", + formatter_class=argparse.RawDescriptionHelpFormatter, + ) + parser.add_argument( + "--repo", + help="Target repository in owner/repo form (overrides GITHUB_REPOSITORY env)", + required=False, + default=None, + ) + parser.add_argument( + "--pr-mode", + action="store_true", + help="Single PR mode (use GITHUB_EVENT_PATH pull_request payload). Default: bulk scan via search API", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Simulate actions without modifying GitHub state", + ) + parser.add_argument( + "-v", + "--verbose", + action="count", + default=0, + help="Increase verbosity (can be used multiple times, e.g., -vv for more verbose)", + ) + parser.add_argument( + "-q", + "--quiet", + action="count", + default=0, + help="Decrease verbosity (can be used multiple times)", + ) + sub = parser.add_subparsers(dest="command", required=True) + + p_label = sub.add_parser( + "label", help="Add backport pending label to merged PRs lacking 'backport', 'backport pending' or version label" + ) + p_label.add_argument( + "--lookback-days", + type=int, + required=False, + default=7, + help="Days to look back (default: 7). 
Ignored in --pr-mode", + ) + p_label.add_argument( + "--remove", + action="store_true", + required=False, + default=False, + help="Removes backport pending label", + ) + + p_remind = sub.add_parser("remind", help="Post reminders on merged PRs still pending backport") + p_remind.add_argument( + "--lookback-days", + type=int, + required=False, + default=7, + help="Days to look back (default: 7). Ignored in --pr-mode", + ) + p_remind.add_argument( + "--pending-reminder-age-days", + type=int, + required=False, + default=14, + help="Days between reminders for the same PR (default: 14). Adds initial reminder if none posted yet.", + ) + + except Exception: + raise RuntimeError("Command parsing failed") + return parser.parse_args() + + +def main(): + args = parse_args() + configure(args) + + LOG.debug(f"Parsed arguments: {args}") + prefetched = prefetch_prs(args.pr_mode, args.lookback_days) + LOG.debug(f"Prefetched {len(prefetched)} PRs for command '{args.command}': {[pr.get('number') for pr in prefetched]}") + match args.command: + case "label": + return run_label(prefetched, args.remove) + case "remind": + return run_remind(prefetched, args.pending_reminder_age_days, args.lookback_days) + case _: + raise NotImplementedError(f"Unknown command {args.command}") + + +if __name__ == "__main__": + main() diff --git a/.github/scripts/track-filter.py b/github_ci_tools/scripts/track-filter.py similarity index 66% rename from .github/scripts/track-filter.py rename to github_ci_tools/scripts/track-filter.py index 92d548941..0297055e2 100644 --- a/.github/scripts/track-filter.py +++ b/github_ci_tools/scripts/track-filter.py @@ -5,11 +5,11 @@ filters = {} # static file paths should be a comma-separated list of files or directories (omitting the trailing '/') -static_paths = os.environ.get("RUN_FULL_CI_WHEN_CHANGED", []) +static_paths: list[str] = os.environ.get("RUN_FULL_CI_WHEN_CHANGED", "").split(",") # Statically include some files that should always trigger a full CI run if 
static_paths: - filters["full_ci"] = [f"{path}/**" if os.path.isdir(path.strip()) else path.strip() for path in static_paths.split(",")] + filters["full_ci"] = [f"{path}/**" if os.path.isdir(path.strip()) else path.strip() for path in static_paths] # Dynamically create filters for each track (top-level subdirectory) in the repo for entry in os.listdir("."): @@ -17,6 +17,6 @@ filters[entry] = [f"{entry}/**"] -with open(".github/filters.yml", "w") as f: +with open("github_ci_tools/filters.yml", "w") as f: yaml.dump(filters, f, default_flow_style=False) -print(f"Created .github/filters.yml with {len(filters)} track(s): {', '.join(filters.keys())}") +print(f"Created github_ci_tools/filters.yml with {len(filters)} track(s): {', '.join(filters.keys())}") diff --git a/github_ci_tools/tests/__init__.py b/github_ci_tools/tests/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/github_ci_tools/tests/conftest.py b/github_ci_tools/tests/conftest.py new file mode 100644 index 000000000..c1aa542ea --- /dev/null +++ b/github_ci_tools/tests/conftest.py @@ -0,0 +1,279 @@ +"""Pytest configuration and fixtures for testing the backport CLI. + +Provides: + - Dynamic loading of the `backport.py` module (so no package __init__ files required). + - An injectable GitHub API mock (`gh_mock`) that records calls and returns + predefined responses or raises exceptions. + - Helper fixtures for creating synthetic PR payloads and reminder comments. + - A convenience fixture to run `configure()` with a minimal argparse.Namespace. 
+ +Usage examples in tests: + + def test_needs_pending_label(backport_mod, pr_no_labels): + assert backport_mod.needs_pending_label(pr_no_labels) + + def test_label_api_called(backport_mod, gh_mock, pr_versioned): + gh_mock.add(f'repos/{TEST_REPO}/labels/backport%20pending', method='GET', response={}) # label exists + gh_mock.add(f'repos/{TEST_REPO}/issues/42/labels', method='POST', response={'ok': True}) + backport_mod.add_pull_request_label(42, backport_mod.PENDING_LABEL) + assert any(f'repos/{TEST_REPO}/issues/42/labels' in c['path'] for c in gh_mock.calls) + +Note: We treat paths exactly as provided to `gh_request` after query param expansion. +If you register a route with query parameters, include the full `path?query=..` string. +""" + +import datetime as dt +import fnmatch +import importlib.util +import os +import sys +from copy import deepcopy +from dataclasses import dataclass +from os.path import dirname, join +from pathlib import Path +from typing import Any +from urllib.parse import urlencode + +import pytest + +from github_ci_tools.scripts import backport +from github_ci_tools.tests.utils import NOW, TEST_REPO, convert_str_to_date + + +# ----------------------- Environment / Config ------------------------ +@pytest.fixture(autouse=True) +def set_env() -> None: + """Session-level mandatory environment variables prior to any module use. + + Can't use the function-scoped `monkeypatch` fixture from a session scope, so + we set values directly on os.environ here. + """ + os.environ["BACKPORT_TOKEN"] = "dummy-token" + os.environ["GITHUB_REPOSITORY"] = TEST_REPO + + +# --------------------------- Module Loader --------------------------- +@pytest.fixture(scope="function") +def backport_mod(monkeypatch) -> Any: + module = backport + fixed = convert_str_to_date(NOW) + + class FixedDateTime(dt.datetime): + @classmethod + def now(cls, tz=None): # noqa: D401 + # Always return an aware datetime. If tz is provided, adjust; otherwise keep UTC. 
+ if tz is None: + return fixed + return fixed.astimezone(tz) + + monkeypatch.setattr(module.dt, "datetime", FixedDateTime) + return module + + +# --------------------------- GitHub Mock ----------------------------- +@dataclass +class MockRouteResponse: + status: int + json: dict[str, Any] | list[dict[str, Any]] + exception: BaseException | None = None + + +@dataclass +class MockCallRecord: + method: str + path: str + body: dict[str, Any] | None + params: dict[str, str] | None + + +class GitHubMock: + """Approximating GitHub REST semantics. + + Routes: Each route is a predefined (path, HTTP method) pair with an associated + response for simulation (JSON, status code, exception, object types). It models + individual GitHub REST API endpoints the code under test might call. Registering + routes lets tests declare exactly which API interactions are expected and which data + or error should be returned, without making real network calls. + + Calls list: Every time the mocked gh_request is invoked, the mock appends an entry + (path, method, body, params) to calls. Tests use this list to assert: + * That expected endpoints were hit (presence/order/count). + * That request bodies or query parameters match what the logic should send. + + In short: + - Routes define allowed interactions + - Calls record actual interactions. + - Assertion helper to verify expected vs actual behavior. + + Routes are matched by fully expanded path and uppercased method. + + Wildcard / glob support: + Register paths containing the literal sequence '...' (three dots). Each '...' + becomes a glob wildcard (*). + For example: '/search/issues?q=a_string...repo...merged...updated...end_string' + -> Glob pattern: '/search/issues?q=a_string*repo*merged*updated*end_string*' + -> Matches any path that starts with 'a_string', contains 'repo', 'merged', 'updated' + in that order, and ends with 'end_string'. 
+ """ + + def __init__(self) -> None: + self._routes: dict[tuple[str, str], MockRouteResponse] = {} + self._glob_routes: list[tuple[str, str, str, MockRouteResponse]] = [] # (METHOD, original, glob_pattern, response) + self.calls: list[MockCallRecord] = [] + + # -------------- Registration ------------ + def add( + self, + method: str = "GET", + path: str = "/repos", + response: dict[str, Any] | list[dict[str, Any]] = {}, + status: int = 200, + exception: BaseException | None = None, + ) -> None: + """Register a route. + + Parameters: + path: API path exactly as gh_request would see + method: HTTP method + response: JSON-serializable object returned to caller + status: HTTP status code (>=400 will raise RuntimeError automatically) + exception: If provided, raised instead of using status/response + """ + m = method.upper() + route_resp = MockRouteResponse(status=status, json=response, exception=exception) + if "..." in path: + parts = path.split("...") + glob_pattern = "*".join(parts) + if not path.endswith("..."): + glob_pattern += "*" + self._glob_routes.append((m, path, glob_pattern, route_resp)) + else: + self._routes[(m, path)] = route_resp + + # -------------- Invocation -------------- + def __call__( + self, + method: str = "GET", + path: str = "repos", + body: dict[str, Any] | None = None, + params: dict[str, str] | None = None, + ) -> Any: + if params: + path = f"{path}?{urlencode(params)}" + path = f"/{path}" + key = (method.upper(), path) + + self.calls.append( + MockCallRecord( + method=method.upper(), + path=path, + body=deepcopy(body), + params=deepcopy(params), + ) + ) + if key not in self._routes: + # Register paths containing several literal sequences '...' which become glob wildcards (*). 
+ route = None + if self._glob_routes: + for m, original, glob_pattern, resp in self._glob_routes: + if m != method.upper(): + continue + if fnmatch.fnmatchcase(path, glob_pattern): + route = resp + break + if route is None: + raise AssertionError( + f"Unexpected GitHub API call: {key}. Registered exact: {list(self._routes.keys())} glob: {[(method, orig) for method,orig,_,_ in self._glob_routes]}" + ) + else: + route = self._routes[key] + + if route.exception: + raise route.exception + if route.status >= 400: + raise RuntimeError(f"HTTP {route.status}: {route.json}") + return deepcopy(route.json) + + # -------------- Assertion --------------- + def assert_calls_in_order(self, *expected: tuple[str, str], strict: bool = True) -> None: + """ + Assert that the provided sequence of (HTTP_METHOD, full_path) tuples appears + in order. When strict=True (default), the recorded calls must match exactly + (same length and element-wise equality). When strict=False, the expected + sequence must appear as an ordered subsequence within the recorded calls. + """ + actual = [(c.method, c.path) for c in self.calls] + if not expected: + if actual: + raise AssertionError(f"Expected no GitHub calls, but saw {len(actual)}: {actual}") + return + + def glob_match(exp: tuple[str, str], act: tuple[str, str]) -> bool: + exp_m, exp_p = exp + act_m, act_p = act + if exp_m != act_m: + return False + if "..." 
in exp_p: + parts = exp_p.split("...") + glob_pattern = "*".join(parts) + if not exp_p.endswith("..."): + glob_pattern += "*" + return fnmatch.fnmatchcase(act_p, glob_pattern) + else: + return exp_p == act_p + + check = True + if strict: + # Build a small diff aid + lines = ["Strict order mismatch:"] + for i, (exp, act) in enumerate(zip(expected, actual)): + max_len = max(len(actual), len(expected)) + for i in range(max_len): + exp = expected[i] if i < len(expected) else ("", "") + act = actual[i] if i < len(actual) else ("", "") + if exp and act and (exp == act or glob_match(exp, act)): + marker = "OK" + else: + marker = "!!" + check = False + lines.append(f"[{i}] exp={exp} act={act} {marker}") + if not check: + raise AssertionError("\n".join(lines)) + return + + # Relaxed subsequence check + it = iter(actual) + for exp in expected: + if not any(expected == act or glob_match(exp, act) for act in it): + check = False + raise AssertionError(f"Expected {expected} not found in actual calls: {actual}") + + # -------------- Convenience -------------- + @property + def calls_list(self) -> list[tuple[str, str]]: + """Return the list of recorded calls as (method, path) tuples for convenience.""" + return [(c.method, c.path) for c in self.calls] + + +@pytest.fixture() +def gh_mock(backport_mod, monkeypatch: pytest.MonkeyPatch) -> GitHubMock: + """Fixture that patches `backport_mod.gh_request` with a controllable mock. + + Use `gh_mock.add(...)` to declare responses before invoking code under test. + + By replacing backport_mod.gh_request with GitHubMock, each test can: + - Declare exactly which endpoints should be called (via add()). + - Control payloads and error codes deterministically. 
+ """ + mock = GitHubMock() + monkeypatch.setattr(backport_mod, "gh_request", mock) + return mock + + +# --------------------------- Event Payload --------------------------- +@pytest.fixture() +def event_file(tmp_path, monkeypatch) -> Path: + """Create a temporary GitHub event JSON.""" + path = tmp_path / "event.json" + monkeypatch.setenv("GITHUB_EVENT_PATH", str(path)) + return path diff --git a/github_ci_tools/tests/resources/__init__.py b/github_ci_tools/tests/resources/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/github_ci_tools/tests/resources/case_registry.py b/github_ci_tools/tests/resources/case_registry.py new file mode 100644 index 000000000..983381b5d --- /dev/null +++ b/github_ci_tools/tests/resources/case_registry.py @@ -0,0 +1,385 @@ +"""Minimal registry of pull request test cases. + +This module intentionally keeps ONLY: + - A single list `PR_CASES` containing all `PullRequestCase` objects. + - A helper `select_pull_requests(**filters)` that filters by attributes that + exist directly on `PullRequestCase` (e.g. number, merged, needs_pending, needs_reminder, labels, etc.). + +No classification metadata, wrapper dataclasses, or pattern axes are retained. +If higher-level categorization is needed, compose it at call sites. 
+""" + +from __future__ import annotations + +from dataclasses import asdict +from enum import Enum +from typing import Any + +from github_ci_tools.tests.resources.cases import PullRequestCase +from github_ci_tools.tests.utils import ( + COMMENTS, + COMMENTS_PER_PAGE, + LABELS, + NOW, + SEARCH_LABELS_PER_PAGE, + TEST_REPO, + GHRoute, + convert_str_to_date, + lookback_cutoff, +) + + +def _pr(**kwargs) -> PullRequestCase: + """Thin helper to create PullRequestCase with minimal kwargs.""" + return PullRequestCase(**kwargs) + + +# ----------------------- Static PR Cases ----------------------- +PR_CASES: list[PullRequestCase] = [ + _pr(number=101, merged_at="2025-10-23T12:00:00Z", needs_pending=True), + _pr( + number=102, + merged_at="2025-10-23T12:00:00Z", + needs_pending=True, + comments=[ + COMMENTS["recent_comment"], + COMMENTS["old_reminder"], + COMMENTS["old_comment"], + ], + ), + _pr( + number=103, + merged_at="2025-10-23T12:00:00Z", + labels=LABELS["versioned_typo"], + needs_pending=True, + comments=[ + COMMENTS["recent_reminder"], + COMMENTS["old_comment"], + ], + ), + _pr( + number=104, + merged_at="2025-10-02T12:00:00Z", + labels=LABELS["versioned"], + needs_reminder=True, + ), + _pr( + number=105, + merged_at="2025-10-02T12:00:00Z", + labels=LABELS["backport"], + comments=[ + COMMENTS["recent_comment"], + COMMENTS["recent_reminder"], + ], + ), + _pr( + number=106, + merged_at="2025-10-02T12:00:00Z", + labels=LABELS["versioned_pending"], + backport_pending_in_labels=True, + needs_reminder=True, + comments=[ + COMMENTS["recent_comment"], + ], + ), + _pr( + number=107, + merged_at="2025-10-02T12:00:00Z", + labels=LABELS["pending_typo"], + needs_pending=True, + comments=[ + COMMENTS["strange_new_comment"], + COMMENTS["old_reminder"], + ], + ), + _pr( + number=108, + merged_at="2025-10-02T12:00:00Z", + labels=LABELS["versioned_pending_typo"], + comments=[ + COMMENTS["recent_comment"], + COMMENTS["marker_in_old_comment_difficult"], + 
COMMENTS["marker_only_new_comment"], + COMMENTS["really_old_reminder"], + ], + ), + _pr( + number=109, + merged_at="2025-10-23T12:00:00Z", + labels=LABELS["pending_typo"], + needs_pending=True, + comments=[ + COMMENTS["recent_comment"], + COMMENTS["marker_in_old_comment_difficult"], + COMMENTS["really_old_reminder"], + ], + ), + # Unmerged PRs should be ignored no matter what their state is. + _pr(number=201, merged=False, needs_pending=True, needs_reminder=True), + _pr( + number=202, + merged=False, + labels=LABELS["versioned"], + comments=[COMMENTS["marker_in_text_of_new_comment"]], + ), + _pr( + number=203, + merged=False, + labels=LABELS["versioned_typo"], + needs_pending=True, + comments=[ + COMMENTS["really_old_reminder"], + ], + ), + _pr( + number=204, + merged=False, + labels=LABELS["versioned_pending_typo"], + comments=[ + COMMENTS["recent_comment"], + ], + ), + # Old merged PRs for lookback and reminder age testing + _pr(number=301, merged_at="2023-10-01T12:00:00Z", needs_pending=True), + _pr( + number=302, + merged_at="2023-10-01T12:00:00Z", + needs_pending=True, + comments=[ + COMMENTS["really_old_reminder"], + ], + ), + _pr( + number=303, + merged_at="2023-10-01T12:00:00Z", + needs_pending=True, + comments=[ + COMMENTS["old_reminder"], + COMMENTS["strange_new_comment"], + COMMENTS["really_old_reminder"], + ], + ), + _pr( + number=304, + merged_at="2023-10-01T12:00:00Z", + labels=LABELS["versioned"], + ), + _pr( + number=305, + merged_at="2023-10-01T12:00:00Z", + labels=LABELS["backport"], + comments=[ + COMMENTS["really_old_reminder"], + ], + ), + _pr( + number=306, + merged_at="2023-10-01T12:00:00Z", + labels=LABELS["backport_typo"], + needs_pending=True, + comments=[ + COMMENTS["marker_in_text_of_new_comment"], + ], + ), + _pr( + number=307, + merged_at="2023-10-01T12:00:00Z", + labels=LABELS["pending"], + backport_pending_in_labels=True, + needs_reminder=True, + comments=[ + COMMENTS["marker_in_old_comment_difficult"], + ], + ), + _pr( + number=308, 
+ merged_at="2023-10-01T12:00:00Z", + labels=LABELS["versioned_typo"], + needs_pending=True, + comments=[ + COMMENTS["old_reminder"], + COMMENTS["strange_new_comment"], + ], + ), + _pr( + number=309, + merged_at="2023-10-01T12:00:00Z", + needs_pending=True, + comments=COMMENTS["120_old_comments"], + ), + # PRs marked for removal of pending label + _pr(number=401, merged_at="2023-10-01T12:00:00Z", needs_pending=True, remove=True), + _pr( + number=402, + merged_at="2023-10-01T12:00:00Z", + labels=LABELS["pending"], + backport_pending_in_labels=True, + needs_reminder=True, + comments=[ + COMMENTS["really_old_reminder"], + ], + remove=True, + ), + _pr( + number=403, + merged_at="2023-10-01T12:00:00Z", + labels=LABELS["pending_typo"], + needs_pending=True, + comments=[ + COMMENTS["old_reminder"], + COMMENTS["strange_new_comment"], + COMMENTS["really_old_reminder"], + ], + remove=True, + ), +] + + +# ----------------------- Selectors ----------------------- +def select_pull_requests(**filters: Any) -> list[PullRequestCase]: + """Return PullRequestCase objects matching direct attribute equality filters. + + Example: select_pull_requests(merged=True, needs_pending=True) + Only attributes present on PullRequestCase are supported. Unknown keys raise ValueError. + lists (e.g. labels) match by direct equality. 
+ """ + if not filters: + return list(PR_CASES) + unsupported = [k for k in filters.keys() if not hasattr(PR_CASES[0], k)] if PR_CASES else [] + if unsupported: + raise ValueError(f"Unsupported filter keys: {unsupported}") + out: list[PullRequestCase] = [] + for pr in PR_CASES: + keep = True + for k, v in filters.items(): + k_val = getattr(pr, k) + if k_val != v: + if isinstance(k_val, list): + if not any(item in v for item in k_val): + keep = False + break + else: + keep = False + break + if keep: + out.append(pr) + return out + + +def case_by_number(number: int) -> PullRequestCase: + return next(pr for pr in PR_CASES if pr.number == number) + + +def select_pull_requests_by_lookback(lookback_days: int, **filters) -> list[PullRequestCase]: + """Return PullRequestCase objects merged within lookback_days from NOW.""" + filtered_prs = select_pull_requests(**filters) + now = convert_str_to_date(NOW) + out: list[PullRequestCase] = [] + for pr in filtered_prs: + if pr.merged and pr.merged_at: + merged_at_date = convert_str_to_date(pr.merged_at) + if merged_at_date >= lookback_cutoff(now, lookback_days): + out.append(pr) + return out + + +# ----------------------- Test case utilities ----------------------- +class GHInteractAction(Enum): + PR_ADD_PENDING_LABEL = "add_pending_label" + PR_REMOVE_PENDING_LABEL = "remove_pending_label" + PR_GET_COMMENTS = "get_comments" + PR_POST_REMINDER_COMMENT = "post_reminder_comment" + REPO_GET_LABELS = "get_repo_labels" + REPO_ADD_LABEL = "add_repo_label" + LIST_PRS = "list_prs" + + +def build_gh_routes_comments(method: str, prs: list[PullRequestCase]) -> list[GHRoute]: + routes = [] + for pr in prs: + if method == "POST": + routes.append( + GHRoute( + f"/repos/{TEST_REPO}/issues/{pr.number}/comments", + method=method, + response={}, + ) + ) + elif method == "GET": + comments_length = len(pr.comments) + if comments_length == 0: + routes.append( + GHRoute( + f"/repos/{TEST_REPO}/issues/{pr.number}/comments...", + method=method, + 
response=[], + ) + ) + continue + num_pages = (comments_length + COMMENTS_PER_PAGE - 1) // COMMENTS_PER_PAGE + + for page in range(1, num_pages + 1): + start_idx = (page - 1) * COMMENTS_PER_PAGE + end_idx = min(start_idx + COMMENTS_PER_PAGE, comments_length) + page_comments = pr.comments[start_idx:end_idx] + routes.append( + GHRoute( + f"/repos/{TEST_REPO}/issues/{pr.number}/comments...&page={page}", + method=method, + response=[asdict(comment) for comment in page_comments], + ) + ) + else: + raise ValueError(f"Unsupported method for comment routes: {method}") + return routes + + +def build_gh_routes_labels(method: str, prs: list[PullRequestCase]) -> list[GHRoute]: + routes = [] + for pr in prs: + routes.append( + GHRoute( + f"/repos/{TEST_REPO}/issues/{pr.number}/labels", + method=method, + response=[{"name": label.name, "color": label.color} for label in pr.labels], + ) + ) + return routes + + +def expected_actions_for_prs(action: GHInteractAction, prs: list[PullRequestCase]) -> list[tuple[str, str]]: + actions = [] + match action: + case GHInteractAction.PR_ADD_PENDING_LABEL: + for pr in prs: + actions.append(("POST", f"/repos/{TEST_REPO}/issues/{pr.number}/labels")) + case GHInteractAction.PR_REMOVE_PENDING_LABEL: + for pr in prs: + actions.append(("DELETE", f"/repos/{TEST_REPO}/issues/{pr.number}/labels")) + case GHInteractAction.PR_GET_COMMENTS: + for pr in prs: + num_pages = (len(pr.comments) + COMMENTS_PER_PAGE - 1) // COMMENTS_PER_PAGE + if num_pages == 0: + actions.append(("GET", f"/repos/{TEST_REPO}/issues/{pr.number}/comments?per_page={COMMENTS_PER_PAGE}&page=1")) + for page in range(1, num_pages + 1): + actions.append(("GET", f"/repos/{TEST_REPO}/issues/{pr.number}/comments?per_page={COMMENTS_PER_PAGE}&page={page}")) + case GHInteractAction.PR_POST_REMINDER_COMMENT: + for pr in prs: + actions.append(("POST", f"/repos/{TEST_REPO}/issues/{pr.number}/comments")) + case GHInteractAction.LIST_PRS: + actions.append(("GET", 
f"/search/issues...merged...updated...")) + case _: + raise ValueError(f"Unsupported PR action: {action}") + return actions + + +def expected_actions_for_repo(action: GHInteractAction) -> list[tuple[str, str]]: + actions = [] + match action: + case GHInteractAction.REPO_GET_LABELS: + actions.append(("GET", f"/repos/{TEST_REPO}/labels?per_page={SEARCH_LABELS_PER_PAGE}")) + case GHInteractAction.REPO_ADD_LABEL: + actions.append(("POST", f"/repos/{TEST_REPO}/labels")) + case _: + raise ValueError(f"Unsupported repo action: {action}") + return actions diff --git a/github_ci_tools/tests/resources/cases.py b/github_ci_tools/tests/resources/cases.py new file mode 100644 index 000000000..cac1e5162 --- /dev/null +++ b/github_ci_tools/tests/resources/cases.py @@ -0,0 +1,174 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from dataclasses import dataclass, field +from typing import Any, Callable, TypeVar + +import pytest + +from github_ci_tools.tests.utils import ( + STATIC_ROUTES, + TEST_REPO, + Comment, + GHRoute, + Label, +) + + +@dataclass +class PullRequestCase: + """Represents a single pull request test scenario. Also used to hold PR state. + + Fields: + labels: Simple list of label names. 
+ comments: Optional list of issue comment dicts (each with at least 'body' and 'created_at'). + Used for testing reminder logic without separate fixtures. + number / merged / merged_at: Basic PR metadata. + remove: Flag you can pass through to label command tests. + """ + + number: int = 42 + labels: list[Label] = field(default_factory=list) + comments: list[Comment] = field(default_factory=list) + needs_pending: bool = False + needs_reminder: bool = False + backport_pending_in_labels: bool = False + merged: bool = True + merged_at: str | None = None + remove: bool = False # To select PRs for label removal. + + def __eq__(self, other: Any) -> bool: + if isinstance(other, dict): + other = PullRequestCase(**other) + return ( + self.number == other.number + and self.labels == other.labels + and self.merged == other.merged + and self.merged_at == other.merged_at + ) + + +@dataclass +class RepoCase: + """Repository-level label scenario. Also used to hold the repository state. + + repo_labels: Names of labels currently defined in the repository. + Used to drive repo_needs_pending_label / ensure_backport_pending_label behavior. + create_raises: Simulate a failure when attempting to create the label. 
+ """ + + name: str = TEST_REPO + prs: list[PullRequestCase] = field(default_factory=list) + repo_labels: list[Label] = field(default_factory=list) + needs_pending: bool = False + create_raises: bool = False + + @property + def repo(self) -> str: + return self.name + + def register(self, gh_mock: Any) -> None: # 'Any' to avoid circular import of GitHubMock + """Register all declared routes on the provided gh_mock instance.""" + static_get_labels = STATIC_ROUTES["get_labels"] + static_create_pending_label = STATIC_ROUTES["create_pending_label"] + existing = [{"name": label.name, "color": label.color} for label in self.repo_labels] + + gh_mock.add("GET", static_get_labels.path, response=existing) + if self.needs_pending: + if self.create_raises: + gh_mock.add(static_create_pending_label.method, static_create_pending_label.path, exception=RuntimeError("fail create")) + else: + gh_mock.add( + static_create_pending_label.method, static_create_pending_label.path, response=static_create_pending_label.response + ) + + pass + + +# ---------------- Unified Interaction Case ----------------- +@dataclass +class GHInteractionCase: + """Unified scenario combining PR / Repo data and expected GitHub interactions. + + This is an optional higher-level abstraction that can replace separate + PullRequestCase + ad-hoc route registration in tests. It is intentionally kept + lightweight so existing tests can migrate incrementally. + + Fields: + prs: List of PullRequestCase objects (supporting multi-PR scenarios like bulk searches). + repo: RepoCase (repository label logic) - unused for current label tests. + routes: Pre-declared list of GHRoute entries to register on gh_mock. + expected_order: Ordered list of (method, path) tuples expected to appear (subsequence); may be empty. + strict: When True, assert call count equals expected_order length. 
+ """ + + repo: RepoCase = field(default_factory=RepoCase) + routes: list[GHRoute] = field(default_factory=list) + lookback_days: int = 7 + pending_reminder_age_days: int = 7 + expected_prefetch_prs: list[dict[str, Any]] | None = field(default_factory=list) + expected_order: list[tuple[str, str]] = field(default_factory=list) + strict: bool = True + raises_error: type[Exception] | None = None + + def register(self, gh_mock: Any) -> None: # 'Any' to avoid circular import of GitHubMock + """Register all declared routes on the provided gh_mock instance.""" + if self.repo: + self.repo.register(gh_mock) + for r in self.routes: + gh_mock.add(method=r.method, path=r.path, response=r.response, status=r.status, exception=r.exception) + + +# ---------------- Backport CLI Definitions ----------------- +@dataclass +class BackportCliCase: + """Table-driven CLI scenario data container.""" + + argv: list[str] = field(default_factory=list) + env: dict[str, str] = field(default_factory=lambda: {"BACKPORT_TOKEN": "tok", "GITHUB_REPOSITORY": TEST_REPO}) + delete_env: list[str] = field(default_factory=list) + expect_parse_exit: bool = False + expect_require_error_substr: str | None = None + expected_config: dict[str, Any] = field(default_factory=dict) + expected_args: dict[str, Any] = field(default_factory=dict) + expected_log_level: int | None = None # When None, log level assertion is skipped. + gh_interaction: GHInteractionCase = field(default_factory=GHInteractionCase) + raises_error: type[Exception] | None = None + + +C = TypeVar("C") + + +def cases(arg_name: str = "case", **table: C) -> Callable: + """cases defines a decorator wrapping `pytest.mark.parametrize` to run a test against multiple cases. + + The purpose of this decorator is to create table-driven unit tests (https://go.dev/wiki/TableDrivenTests). + + param arg_name: the name of the parameter used for the input table (by default is 'case'). 
    :param table:
        a dictionary of per use case entries that represent the test input table. It typically contains
        either input parameters and configuration for initial test case status (or fixtures)
    :return: a test method decorator.

    Usage:
        @cases(
            no_labels=PullRequestCase(labels=[], needs_pending=True),
        )
        def test_create(case):
            assert backport_mod.pr_needs_pending_label(backport_mod.PRInfo.from_dict(asdict(case))) is case.needs_pending
    """
    return pytest.mark.parametrize(argnames=arg_name, argvalues=list(table.values()), ids=list(table.keys()))
diff --git a/github_ci_tools/tests/test_backport_cli.py b/github_ci_tools/tests/test_backport_cli.py
new file mode 100644
index 000000000..ddc982e49
--- /dev/null
+++ b/github_ci_tools/tests/test_backport_cli.py
@@ -0,0 +1,254 @@
import json
import logging
import sys
from dataclasses import asdict

import pytest

from github_ci_tools.tests.resources.case_registry import (
    GHInteractAction,
    build_gh_routes_labels,
    case_by_number,
    expected_actions_for_prs,
    expected_actions_for_repo,
    select_pull_requests,
    select_pull_requests_by_lookback,
)
from github_ci_tools.tests.resources.cases import (
    BackportCliCase,
    GHInteractionCase,
    RepoCase,
    cases,
)
from github_ci_tools.tests.utils import TEST_REPO, GHRoute


# CLI parsing / configuration matrix: each case drives parse_args() and configure().
@cases(
    label_basic=BackportCliCase(
        argv=["backport.py", "--dry-run", "-vv", "label", "--lookback-days", "30"],
        env={"BACKPORT_TOKEN": "tok"},
        expected_args={"command": "label", "lookback_days": 30, "dry_run": True, "verbose": 2},
        expected_config={"repo": TEST_REPO, "dry_run": True, "command": "label", "verbose": 2, "quiet": 0},
        expected_log_level=logging.NOTSET,
    ),
    label_default_lookback=BackportCliCase(
        argv=["backport.py", "label"],
        env={"BACKPORT_TOKEN": "tok"},
        expected_args={"command": "label", "lookback_days": 7},
        expected_config={"repo": TEST_REPO, "command": "label", "verbose": 0, "quiet": 0},
        expected_log_level=logging.INFO,
    ),
    label_override_lookback=BackportCliCase(
        argv=["backport.py", "label", "--lookback-days", "45"],
        env={"BACKPORT_TOKEN": "tok"},
        expected_args={"command": "label", "lookback_days": 45},
        expected_config={"repo": TEST_REPO, "command": "label", "verbose": 0, "quiet": 0},
        expected_log_level=logging.INFO,
    ),
    remind_basic=BackportCliCase(
        argv=["backport.py", "remind", "--lookback-days", "10", "--pending-reminder-age-days", "5"],
        env={"BACKPORT_TOKEN": "tok", "GITHUB_REPOSITORY": TEST_REPO},
        expected_args={"command": "remind", "lookback_days": 10, "pending_reminder_age_days": 5},
        expected_config={"repo": TEST_REPO, "command": "remind", "verbose": 0, "quiet": 0},
        expected_log_level=logging.INFO,
    ),
    remind_default_pending_age=BackportCliCase(
        argv=["backport.py", "remind"],
        env={"BACKPORT_TOKEN": "tok"},
        expected_args={"command": "remind", "lookback_days": 7, "pending_reminder_age_days": 14},
        expected_config={"repo": TEST_REPO, "command": "remind", "verbose": 0, "quiet": 0},
        expected_log_level=logging.INFO,
    ),
    remind_override_pending_age=BackportCliCase(
        argv=["backport.py", "remind", "--lookback-days", "3", "--pending-reminder-age-days", "14"],
        env={"BACKPORT_TOKEN": "tok"},
        expected_args={"command": "remind", "lookback_days": 3, "pending_reminder_age_days": 14},
        expected_config={"repo": TEST_REPO, "command": "remind", "verbose": 0, "quiet": 0},
        expected_log_level=logging.INFO,
    ),
    missing_command=BackportCliCase(
        argv=["backport.py"],
        env={"BACKPORT_TOKEN": "tok", "GITHUB_REPOSITORY": "acme/repo"},
        expect_parse_exit=True,
    ),
    missing_token=BackportCliCase(
        argv=["backport.py", "label"],
        delete_env=["BACKPORT_TOKEN"],
        expect_require_error_substr="Missing BACKPORT_TOKEN",
    ),
    missing_repo=BackportCliCase(
        argv=["backport.py", "label"],
        delete_env=["GITHUB_REPOSITORY"],
        expect_require_error_substr="Missing or invalid GITHUB_REPOSITORY",
    ),
)
def test_backport_cli_parsing(backport_mod, monkeypatch, case: BackportCliCase):
    # Environment setup
    for k, v in case.env.items():
        monkeypatch.setenv(k, v)
    for k in case.delete_env:
        monkeypatch.delenv(k, raising=False)
    monkeypatch.setattr(sys, "argv", case.argv)

    # Cases with no subcommand must exit during argparse.
    if case.expect_parse_exit:
        with pytest.raises(SystemExit):
            backport_mod.parse_args()
        return

    args = backport_mod.parse_args()

    # Validate argparse Namespace expectations
    for key, expected in case.expected_args.items():
        assert getattr(args, key) == expected

    # Attempt configure (which calls require_mandatory_vars)
    if case.expect_require_error_substr:
        with pytest.raises(RuntimeError) as exc:
            backport_mod.configure(args)
        assert case.expect_require_error_substr in str(exc.value)
        return

    backport_mod.configure(args)

    # Validate CONFIG state
    for key, expected in case.expected_config.items():
        assert getattr(backport_mod.CONFIG, key) == expected

    # Optional log level assertion provided by test case (error cases may skip)
    if case.expected_log_level is not None:
        if isinstance(case.expected_log_level, int):
            assert backport_mod.CONFIG.log_level == case.expected_log_level
        else:
            # Callable hook variant: the case supplies its own assertion.
            case.expected_log_level(backport_mod)


@cases(
    merged_recently=GHInteractionCase(
        repo=RepoCase(prs=[case_by_number(101)]),
        lookback_days=7,
        expected_prefetch_prs=[asdict(case_by_number(101))],
    ),
    merged_old=GHInteractionCase(
        repo=RepoCase(prs=[case_by_number(108)]),
        lookback_days=10,
        expected_prefetch_prs=None,
    ),
    merged_really_old_but_still_in_window=GHInteractionCase(
        repo=RepoCase(prs=[case_by_number(301)]),
        lookback_days=1200,
        expected_prefetch_prs=[asdict(case_by_number(301))],
    ),
    unmerged_raises_error=GHInteractionCase(
        repo=RepoCase(prs=[case_by_number(202)]),
        lookback_days=1200,
        expected_prefetch_prs=None,
        raises_error=RuntimeError,
    ),
)
def test_prefetch_prs_in_single_pr_mode(backport_mod, event_file, case: GHInteractionCase):
    # Prepare event payload file for single PR mode
    payload = {"pull_request": asdict(case.repo.prs[0])}
    event_file.write_text(json.dumps(payload), encoding="utf-8")

    # Prefetched PRs must be one, None or raise error
    if case.raises_error:
        with pytest.raises(case.raises_error):
            prefetched_prs = backport_mod.prefetch_prs(pr_mode=True, lookback_days=case.lookback_days)
        return
    prefetched_prs = backport_mod.prefetch_prs(pr_mode=True, lookback_days=case.lookback_days)
    if prefetched_prs:
        assert len(prefetched_prs) == 1
    # Normalize an empty/falsy result to None for comparison against the case.
    prefetched_pr = prefetched_prs if prefetched_prs else None

    assert prefetched_pr == case.expected_prefetch_prs


@cases(
    adds_repo_label_and_labels_only_w=BackportCliCase(
        argv=["backport.py", "label"],
        gh_interaction=GHInteractionCase(
            repo=RepoCase(repo_labels=[], prs=select_pull_requests()),
            lookback_days=7,
            expected_prefetch_prs=[asdict(pr) for pr in select_pull_requests_by_lookback(7)],
            routes=[
                GHRoute(
                    path=f"/search/issues...merged...updated...",
                    method="GET",
                    response={"items": [asdict(pr) for pr in select_pull_requests_by_lookback(7)]},
                ),
                *build_gh_routes_labels("GET", select_pull_requests_by_lookback(7)),
                *build_gh_routes_labels("POST", select_pull_requests_by_lookback(7)),
            ],
            expected_order=[
                *expected_actions_for_prs(GHInteractAction.LIST_PRS, select_pull_requests_by_lookback(7)),
                *expected_actions_for_repo(GHInteractAction.REPO_GET_LABELS),
                *expected_actions_for_repo(GHInteractAction.REPO_ADD_LABEL),
                *expected_actions_for_prs(GHInteractAction.PR_ADD_PENDING_LABEL, select_pull_requests_by_lookback(7)),
            ],
        ),
    ),
    reminds_those_within_pending=BackportCliCase(
        argv=["backport.py", "remind", "--lookback-days", "7", "--pending-reminder-age-days", "30"],
        gh_interaction=GHInteractionCase(
            # Has all the PRs
            repo=RepoCase(prs=select_pull_requests()),
            lookback_days=7,
            pending_reminder_age_days=30,
            expected_prefetch_prs=[asdict(pr) for pr in select_pull_requests_by_lookback(7)],
            routes=[
                GHRoute(
                    path=f"/search/issues...merged...updated...",
                    method="GET",
                    # Prefetches only within 7 days (lookback)
                    response={"items": [asdict(pr) for pr in select_pull_requests_by_lookback(7)]},
                ),
                *build_gh_routes_labels("GET", select_pull_requests_by_lookback(7)),
            ],
            expected_order=[
                # Actions are dynamically created based on the needs_pending and needs_reminder flags
                *expected_actions_for_prs(GHInteractAction.LIST_PRS, select_pull_requests_by_lookback(7)),
            ],
        ),
    ),
)
def test_backport_run(backport_mod, gh_mock, monkeypatch, case: BackportCliCase):
    """Basic sanity test of run_backport_cli."""
    case.gh_interaction.register(gh_mock)

    # Environment setup
    for k, v in case.env.items():
        monkeypatch.setenv(k, v)
    for k in case.delete_env:
        monkeypatch.delenv(k, raising=False)
    monkeypatch.setattr(sys, "argv", case.argv)

    args = backport_mod.parse_args()
    backport_mod.configure(args)

    prefetched = backport_mod.prefetch_prs(args.pr_mode, args.lookback_days)
    try:
        match args.command:
            case "label":
                result = backport_mod.run_label(prefetched, args.remove)
            case "remind":
                result = backport_mod.run_remind(
                    prefetched,
                    args.pending_reminder_age_days,
                    args.lookback_days,
                )
                # Extend the expected order with per-PR reminder interactions,
                # derived from the same flags the code under test consults.
                for pr in prefetched:
                    if pr.get("needs_pending", False) is False:
                        case.gh_interaction.expected_order += expected_actions_for_prs(
                            GHInteractAction.PR_GET_COMMENTS, [case_by_number(pr.get("number"))]
                        )
                    if pr.get("needs_reminder", False):
                        case.gh_interaction.expected_order += expected_actions_for_prs(
                            GHInteractAction.PR_POST_REMINDER_COMMENT, [case_by_number(pr.get("number"))]
                        )
            case _:
                pytest.fail(f"Unknown command {args.command}")
    except Exception as e:
        pytest.fail(f"backport_run raised unexpected exception: {e}")

    assert result == 0
    gh_mock.assert_calls_in_order(*case.gh_interaction.expected_order)
diff --git a/github_ci_tools/tests/test_backport_label.py b/github_ci_tools/tests/test_backport_label.py
new file mode 100644
index 000000000..7034ec5f5 --- /dev/null +++ b/github_ci_tools/tests/test_backport_label.py @@ -0,0 +1,76 @@ +from dataclasses import asdict + +from github_ci_tools.tests.resources.case_registry import ( + GHInteractAction, + build_gh_routes_labels, + case_by_number, + expected_actions_for_prs, + select_pull_requests, +) +from github_ci_tools.tests.resources.cases import GHInteractionCase, RepoCase, cases +from github_ci_tools.tests.utils import LABELS, STATIC_ROUTES + + +@cases( + exists_dont_create=RepoCase(repo_labels=LABELS["pending"]), + no_label_repo_creates=RepoCase(repo_labels=[], needs_pending=True), + no_label_but_gh_error=RepoCase(repo_labels=[], needs_pending=True, create_raises=True), + ignore_duplicate_pending=RepoCase(repo_labels=LABELS["pending_duplicate"]), + only_backport_label_creates=RepoCase(repo_labels=LABELS["backport"], needs_pending=True), + labels_with_pending_typo_creates=RepoCase(repo_labels=LABELS["pending_typo"], needs_pending=True), + labels_with_backport_typo_creates=RepoCase(repo_labels=LABELS["backport_typo"], needs_pending=True), +) +def test_repo_ensure_backport_pending_label(backport_mod, gh_mock, caplog, case: RepoCase): + """Ensure creation only when PENDING_LABEL is strictly absent.""" + static_get_labels = STATIC_ROUTES["get_labels"] + static_create_pending_label = STATIC_ROUTES["create_pending_label"] + existing = [{"name": label.name, "color": label.color} for label in case.repo_labels] + + gh_mock.add("GET", static_get_labels.path, response=existing) + + if case.needs_pending: + if case.create_raises: + gh_mock.add(static_create_pending_label.method, static_create_pending_label.path, exception=RuntimeError("fail create")) + else: + gh_mock.add(static_create_pending_label.method, static_create_pending_label.path, response=static_create_pending_label.response) + + backport_mod.ensure_backport_pending_label() + assertions = [(static_get_labels.method, static_get_labels.path)] + if case.needs_pending: + 
assertions.append((static_create_pending_label.method, static_create_pending_label.path)) + gh_mock.assert_calls_in_order(*assertions) + if case.create_raises and case.needs_pending: + assert any(f"{backport_mod.COULD_NOT_CREATE_LABEL_WARNING}" in rec.message for rec in caplog.records) + + +@cases( + add_to_single_pr_with_no_label=GHInteractionCase( + repo=RepoCase(prs=[case_by_number(101)]), + routes=[ + *build_gh_routes_labels("POST", [case_by_number(101)]), + ], + ), + add_pull_request_label_only_to_those_needs_pending=GHInteractionCase( + repo=RepoCase(prs=select_pull_requests(remove=False)), + routes=[*build_gh_routes_labels("POST", select_pull_requests(remove=False))], + ), + remove_pull_request_label_only_for_those_that_has_pending=GHInteractionCase( + repo=RepoCase(prs=select_pull_requests(remove=True)), + routes=[*build_gh_routes_labels("DELETE", select_pull_requests(remove=True))], + ), +) +def test_label_logic(backport_mod, gh_mock, case: GHInteractionCase): + """Test of the exact logic as in run_label.""" + case.register(gh_mock) + for pr in case.repo.prs: + # Test of the exact logic as in run_label + pr_info = backport_mod.PRInfo.from_dict(asdict(pr)) + assert backport_mod.pr_needs_pending_label(pr_info) is pr.needs_pending + + if pr.remove: + backport_mod.remove_pull_request_label(pr.number, backport_mod.PENDING_LABEL) + case.expected_order += expected_actions_for_prs(GHInteractAction.PR_REMOVE_PENDING_LABEL, [case_by_number(pr.number)]) + elif pr.needs_pending: + backport_mod.add_pull_request_label(pr.number, backport_mod.PENDING_LABEL) + case.expected_order += expected_actions_for_prs(GHInteractAction.PR_ADD_PENDING_LABEL, [case_by_number(pr.number)]) + gh_mock.assert_calls_in_order(*case.expected_order) diff --git a/github_ci_tools/tests/test_backport_reminder.py b/github_ci_tools/tests/test_backport_reminder.py new file mode 100644 index 000000000..c0009627b --- /dev/null +++ b/github_ci_tools/tests/test_backport_reminder.py @@ -0,0 +1,128 @@ 
+from dataclasses import asdict + +from github_ci_tools.tests.resources.case_registry import ( + GHInteractAction, + build_gh_routes_comments, + case_by_number, + expected_actions_for_prs, + select_pull_requests, +) +from github_ci_tools.tests.resources.cases import ( + GHInteractionCase, + PullRequestCase, + RepoCase, + cases, +) +from github_ci_tools.tests.utils import COMMENT_MARKER_BASE + + +@cases( + pr_with_no_comments=case_by_number(201), + pr_with_2_reminders=case_by_number(302), + pr_with_120_old_comments=case_by_number(309), + pr_with_really_old_reminder=case_by_number(303), + pr_with_no_reminders=case_by_number(203), +) +def test_last_reminder_time(backport_mod, case: PullRequestCase): + """Test determining the last reminder time from issue comments.""" + last_reminder = backport_mod.last_reminder_time([asdict(comment) for comment in case.comments], backport_mod.COMMENT_MARKER_BASE) + expected_reminders = [comment for comment in case.comments if comment.is_reminder] + if expected_reminders: + expected_last_reminder = max(expected_reminders, key=lambda c: c.created_at_dt()).created_at_dt() + assert last_reminder == expected_last_reminder + else: + assert last_reminder is None + + +@cases( + from_single_pr_with_no_comments=GHInteractionCase( + repo=RepoCase( + prs=[case_by_number(201)], + ), + routes=[ + *build_gh_routes_comments("GET", [case_by_number(201)]), + ], + expected_order=[ + *expected_actions_for_prs(GHInteractAction.PR_GET_COMMENTS, [case_by_number(201)]), + ], + ), + from_single_pr_with_120_old_comments=GHInteractionCase( + repo=RepoCase( + prs=[case_by_number(309)], + ), + routes=[ + *build_gh_routes_comments("GET", [case_by_number(309)]), + ], + expected_order=[ + *expected_actions_for_prs(GHInteractAction.PR_GET_COMMENTS, [case_by_number(309)]), + ], + ), + fetch_from_all_prs=GHInteractionCase( + repo=RepoCase(prs=select_pull_requests()), + routes=[ + *build_gh_routes_comments("GET", select_pull_requests()), + ], + expected_order=[ + 
*expected_actions_for_prs(GHInteractAction.PR_GET_COMMENTS, select_pull_requests()), + ], + ), +) +def test_get_issue_comments(backport_mod, gh_mock, case: GHInteractionCase): + """Test fetching issue comments with pagination.""" + case.register(gh_mock) + for pr in case.repo.prs: + total_comments = len(pr.comments) + fetched_comments = backport_mod.get_issue_comments(pr.number) + assert len(fetched_comments) == total_comments + for comment in fetched_comments: + assert comment in [asdict(c) for c in pr.comments] + gh_mock.assert_calls_in_order(*case.expected_order) + + +@cases( + # pr_that_has_no_pending_label_does_not_get_commented=GHInteractionCase( + # repo=RepoCase(prs=[case_by_number(101)]), + # routes=[ + # *build_gh_routes_comments("GET", [case_by_number(101)]), + # ], + # ), + # pr_has_pending_label_and_needs_reminder_gets_one=GHInteractionCase( + # repo=RepoCase(prs=[case_by_number(108)]), + # routes=[ + # *build_gh_routes_comments("GET", [case_by_number(108)]), + # *build_gh_routes_comments("POST", [case_by_number(108)]), + # ], + # ), + all_prs_have_pending_label_and_needs_reminder_get_one=GHInteractionCase( + repo=RepoCase(prs=select_pull_requests(backport_pending_in_labels=True, needs_reminder=True)), + routes=[ + *build_gh_routes_comments("GET", select_pull_requests(backport_pending_in_labels=True, needs_reminder=True)), + *build_gh_routes_comments("POST", select_pull_requests(backport_pending_in_labels=True, needs_reminder=True)), + ], + ), + # all_prs_not_need_pending_and_has_reminder_does_not_get_one=GHInteractionCase( + # repo=RepoCase(prs=select_pull_requests(backport_pending_in_labels=False, needs_reminder=False)), + # routes=[ + # *build_gh_routes_comments("GET", select_pull_requests(backport_pending_in_labels=False, needs_reminder=False)), + # ], + # ), +) +def test_remind_logic(backport_mod, gh_mock, case: GHInteractionCase): + """Test of the exact logic as in run_label.""" + case.register(gh_mock) + threshold = 
backport_mod.dt.datetime.now(backport_mod.dt.timezone.utc) - backport_mod.dt.timedelta(days=case.pending_reminder_age_days) + for pr in case.repo.prs: + # Test of the exact logic as in run_remind. + + needs_reminder = backport_mod.pr_needs_reminder(backport_mod.PRInfo.from_dict(asdict(pr)), threshold) + assert needs_reminder is pr.needs_reminder + + if needs_reminder: + backport_mod.add_comment(pr.number, f"{COMMENT_MARKER_BASE}") + + if pr.needs_pending is False: + case.expected_order += expected_actions_for_prs(GHInteractAction.PR_GET_COMMENTS, [case_by_number(pr.number)]) + if pr.needs_reminder: + case.expected_order += expected_actions_for_prs(GHInteractAction.PR_POST_REMINDER_COMMENT, [case_by_number(pr.number)]) + + gh_mock.assert_calls_in_order(*case.expected_order) diff --git a/github_ci_tools/tests/utils.py b/github_ci_tools/tests/utils.py new file mode 100644 index 000000000..31a7110ed --- /dev/null +++ b/github_ci_tools/tests/utils.py @@ -0,0 +1,167 @@ +# Licensed to Elasticsearch B.V. under one or more contributor +# license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright +# ownership. Elasticsearch B.V. licenses this file to you under +# the Apache License, Version 2.0 (the "License"); you may +# not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +import datetime as dt +from dataclasses import dataclass, field +from typing import Any + + +# ------------------- Date helpers ----------------- +def convert_str_to_date(date_str: str) -> dt.datetime: + """Convert dates in ISO 8601 to datetime object.""" + try: + return dt.datetime.strptime(date_str, ISO_FORMAT).replace(tzinfo=dt.timezone.utc) + except ValueError as e: + raise RuntimeError(f"Invalid date format: {date_str}") from e + + +def lookback_cutoff(now, lookback: int) -> dt.datetime: + """Return the cutoff datetime for a given lookback period in days.""" + return now - dt.timedelta(days=lookback) + + +# ---------------- GitHub Helper Definitions ----------------- +@dataclass +class Label: + """Represents a single label.""" + + name: str = field(default_factory=str) + color: str = field(default="ffffff") + + def __eq__(self, other: Any) -> bool: + if not isinstance(other, Label): + return NotImplemented + return self.name == other.name + + def get(self, key: str, default: str = "") -> str: + return getattr(self, key, default) + + +@dataclass +class Comment: + """Represents a single issue comment. + + Fields: + body: The text content of the comment. + created_at: ISO 8601 timestamp string representing when the comment was created. + """ + + body: str + created_at: str + is_reminder: bool = False + + def created_at_dt(self) -> dt.datetime: + """Return the `created_at` value parsed as a timezone-aware UTC datetime. + Raises: + RuntimeError: If `created_at` is missing or not in the expected ISO 8601 format. + """ + if not self.created_at: + raise RuntimeError("Missing created_at field") + try: + return dt.datetime.strptime(self.created_at, ISO_FORMAT).replace(tzinfo=dt.timezone.utc) + except ValueError as e: + raise RuntimeError(f"Invalid created_at format: {self.created_at}") from e + + def get(self, key: str, default: str = "") -> str: + return getattr(self, key, default) + + +@dataclass +class GHRoute: + """Single GitHub route definition for a scenario. 
+ + Fields: + path: Fully expanded path expected (query params included if any). + method: HTTP method (default GET). + response: JSON (dict or list) returned by mock. + status: HTTP status (>=400 triggers RuntimeError in mock invocation) + exception: If set, raised instead of using status/response. + """ + + path: str + method: str = "GET" + response: dict[str, Any] | list[dict[str, Any]] = field(default_factory=dict) + status: int = 200 + exception: BaseException | None = None + + +# ------------------- Constants ----------------- + +TEST_REPO = "test/repo" + +SEARCH_LABELS_PER_PAGE = 100 +SEARCH_ISSUES_PER_PAGE = 100 + +# We define a NOW constant for consistent use in tests. +NOW = "2025-10-30T12:00:00Z" +ISO_FORMAT = "%Y-%m-%dT%H:%M:%SZ" + +AGES = { + "old_7_days": dt.datetime.strptime("2025-10-23T12:00:00Z", ISO_FORMAT).replace(tzinfo=dt.timezone.utc), + "old_14_days": dt.datetime.strptime("2025-10-16T12:00:00Z", ISO_FORMAT).replace(tzinfo=dt.timezone.utc), + "really_old": dt.datetime.strptime("2023-10-01T12:00:00Z", ISO_FORMAT).replace(tzinfo=dt.timezone.utc), +} + +# Note that we should not import from the backport module directly to avoid circular imports. 
+PENDING_LABEL = "backport pending" +PENDING_LABEL_COLOR = "fff2bf" + +LABELS = { + "pending": [Label(PENDING_LABEL)], + "pending_typo": [Label(name) for name in ["backport pend", "Backport Pending", "BACKPORT PENDING"]], + "pending_duplicate": [Label(name) for name in [PENDING_LABEL, PENDING_LABEL]], + "backport": [Label(name) for name in ["backport"]], + "backport_typo": [Label(name) for name in ["backprt", "back-port", "Backport"]], + "versioned": [Label(name) for name in ["v9.2"]], + "versioned_typo": [Label(name) for name in ["v9.2124215s", "123.v2.1sada", "version9.2", "..v9.2..", "v!@#9.20%^@"]], + "versioned_pending": [Label(name) for name in ["v9.2", PENDING_LABEL]], # for remove tests + "versioned_pending_typo": [Label(name) for name in ["v9.2", "backport pend"]], +} + +COMMENT_MARKER_BASE = "" +COMMENTS = { + "recent_reminder": Comment(f"{COMMENT_MARKER_BASE}\nThis is a recent reminder.", created_at="2025-10-23T12:00:00Z", is_reminder=True), + "old_reminder": Comment(f"{COMMENT_MARKER_BASE}\nThis is an old reminder.", created_at="2025-10-01T12:00:00Z", is_reminder=True), + "really_old_reminder": Comment( + f"\nThis is a really old reminder.{COMMENT_MARKER_BASE}", created_at="2023-10-01T12:00:00Z", is_reminder=True + ), + "old_comment": Comment("This is just a regular old comment without any markers.", created_at="2025-10-01T12:00:00Z", is_reminder=False), + "strange_new_comment": Comment( + "@!#%@!@# This is a strange comment without any markers. 
$$$%^&*()", created_at="2025-10-23T12:00:00Z", is_reminder=False + ), + "recent_comment": Comment( + "This is just a regular recent comment without any markers.", created_at="2025-10-23T12:00:00Z", is_reminder=False + ), + "marker_only_new_comment": Comment(COMMENT_MARKER_BASE, created_at="2025-10-23T12:00:00Z", is_reminder=False), + "marker_in_text_of_new_comment": Comment( + f"Please note: {COMMENT_MARKER_BASE} this is important.", created_at="2025-10-23T12:00:00Z", is_reminder=False + ), + "marker_in_old_comment_difficult": Comment( + f"sadcas@!#!@<<<<{COMMENT_MARKER_BASE}>>>>sadcas12!$@%!", created_at="2025-10-10T12:00:00Z", is_reminder=False + ), + "120_old_comments": [Comment(f"Comment number {i}", created_at="2025-10-01T12:00:00Z", is_reminder=False) for i in range(120)], +} +COMMENTS_PER_PAGE = 100 + + +STATIC_ROUTES = { + "create_pending_label": GHRoute( + path=f"/repos/{TEST_REPO}/labels", method="POST", response={"name": PENDING_LABEL, "color": PENDING_LABEL_COLOR} + ), + "get_labels": GHRoute(path=f"/repos/{TEST_REPO}/labels?per_page={SEARCH_LABELS_PER_PAGE}", method="GET", response=[]), + "search_issues": GHRoute(path=f"/search/issues...merged...updated...", method="GET", response={}), +}