diff --git a/tools/python/cherry_pick.py b/tools/python/cherry_pick.py
new file mode 100644
index 0000000000000..96a91c5e74c22
--- /dev/null
+++ b/tools/python/cherry_pick.py
@@ -0,0 +1,324 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+"""
+Cherry-Pick Helper Script
+-------------------------
+Description:
+    This script automates the process of cherry-picking commits for a release branch.
+    It fetches merged PRs with a specific label, sorts them by merge date, and generates:
+    1. A batch file (.cmd) with git cherry-pick commands.
+    2. A markdown file (.md) for the PR description.
+    It also checks for potential missing dependencies (conflicts) by verifying if files modified
+    by the cherry-picked commits have any other modifications in commits that are not in the
+    specified target branch and are not included in the cherry-pick list.
+
+Usage:
+    python cherry_pick.py --label "release:1.24.2" --output cherry_pick.cmd --branch "origin/rel-1.24.2"
+
+Requirements:
+    - Python 3.7+
+    - GitHub CLI (gh) logged in.
+    - Git available in PATH.
+"""
+
+import argparse
+import json
+import os
+import sys
+
+from cherry_pick_utils import (
+    check_preflight,
+    extract_pr_numbers,
+    get_pr_number_from_subject,
+    run_command,
+)
+
+
+def get_merged_prs(repo, label, limit=200):
+    """Fetch merged PRs with the specific label."""
+    print(f"Fetching merged PRs with label '{label}' from {repo}...")
+    cmd = [
+        "gh",
+        "pr",
+        "list",
+        "--repo",
+        repo,
+        "--label",
+        label,
+        "--state",
+        "merged",
+        "--json",
+        "number,title,mergeCommit,mergedAt",
+        "-L",
+        str(limit),
+    ]
+    output = run_command(cmd)
+    if not output:
+        return []
+
+    try:
+        return json.loads(output)
+    except json.JSONDecodeError as e:
+        print(f"Error parsing gh output: {e}", file=sys.stderr)
+        return []
+
+
+def get_changed_files(oid):
+    """Get list of files changed in a commit."""
+    output = run_command(["git", "diff-tree", "--no-commit-id", "--name-only", "-m", "-r", oid], silent=True)
+    if output:
+        return output.strip().splitlines()
+    return []
+
+
+def sanitize_title(title):
+    """Normalize PR titles for single-line text output."""
+    return title.replace("\n", " ").strip()
+
+
+def escape_markdown_table_cell(text):
+    """Escape markdown table delimiters in cell content."""
+    return sanitize_title(text).replace("|", "\\|")
+
+
+def get_existing_pr_numbers(branch, repo=None, log_limit=500):
+    """Get the set of PR numbers already present in the target branch."""
+    output = run_command(["git", "log", branch, "--oneline", "-n", str(log_limit)], silent=True)
+    if not output:
+        return set()
+    pr_numbers = set()
+
+    # Pre-fetch PR cache to avoid redundant gh calls
+    pr_cache = {}
+
+    # Process commit log
+    lines = output.strip().splitlines()
+    for line in lines:
+        parts = line.split(" ", 1)
+        if len(parts) < 2:
+            continue
+        subject = parts[1]
+
+        pr_num = get_pr_number_from_subject(subject)
+        if not pr_num:
+            continue
+
+        pr_num_int = int(pr_num)
+        pr_numbers.add(pr_num_int)
+
+        # Check if it's a cherry-pick / meta-PR
+        is_meta_pr = (
+            "cherry pick" in subject.lower() or "cherry-pick" in subject.lower() or "cherrypick" in subject.lower()
+        )
+
+        if is_meta_pr:
+            # Query gh to get more details (body/commits) to find squashed sub-PRs
+            if pr_num not in pr_cache:
+                gh_cmd = ["gh", "pr", "view", pr_num, "--json", "title,body,commits"]
+                if repo:
+                    gh_cmd.extend(["--repo", repo])
+                gh_out = run_command(gh_cmd, silent=True)
+                if gh_out:
+                    try:
+                        pr_cache[pr_num] = json.loads(gh_out)
+                    except json.JSONDecodeError:
+                        pr_cache[pr_num] = None
+                else:
+                    pr_cache[pr_num] = None
+
+            details = pr_cache.get(pr_num)
+            if details:
+                # Collect sub-PRs from title, body, and commits
+                extracted_nums = []
+                extracted_nums.extend(extract_pr_numbers(details.get("title", ""), strict=True))
+                extracted_nums.extend(extract_pr_numbers(details.get("body", ""), strict=True))
+
+                for commit in details.get("commits", []):
+                    extracted_nums.extend(extract_pr_numbers(commit.get("messageHeadline", ""), strict=True))
+
+                for num in set(extracted_nums):
+                    if num != pr_num_int:
+                        pr_numbers.add(num)
+
+    return pr_numbers
+
+
+def check_missing_dependencies(prs, branch):
+    """Warn about commits that touch the same files but are not being cherry-picked.
+
+    Args:
+        prs: PR dicts to cherry-pick; entries without a "mergeCommit" are ignored.
+        branch: Target branch; commits already reachable from it are not reported.
+    """
+    print("\nChecking for potential missing dependencies (conflicts)...")
+
+    # Collect OIDs being cherry-picked and all their ancestor commits
+    cherry_pick_oids = set()
+    for pr in prs:
+        if pr.get("mergeCommit"):
+            merge_oid = pr["mergeCommit"]["oid"]
+            cherry_pick_oids.add(merge_oid)
+            # Include ancestor commits of merge commits to avoid false-positive warnings
+            # for PRs that used a regular merge (not squash) strategy
+            ancestor_output = run_command(["git", "log", "--format=%H", merge_oid, "--not", branch], silent=True)
+            if ancestor_output:
+                for ancestor_oid in ancestor_output.strip().splitlines():
+                    cherry_pick_oids.add(ancestor_oid.strip())
+
+    conflicting_prs_count = 0
+    for pr in prs:
+        if not pr.get("mergeCommit"):
+            continue
+
+        oid = pr["mergeCommit"]["oid"]
+        number = pr["number"]
+
+        files = get_changed_files(oid)
+        if not files:
+            continue
+
+        # For each file, find commits that modified it between the target branch and the cherry-picked commit.
+        # Deduplicate warnings: group affected files by missing commit.
+        # missing_commits maps: missing_commit_oid -> {"title": ..., "files": [...]}
+        missing_commits = {}
+
+        for filepath in files:
+            # Equivalent to: git log <oid> --not <branch> --format="%H %s" -- <filepath>
+            output = run_command(["git", "log", oid, "--not", branch, "--format=%H %s", "--", filepath], silent=True)
+
+            if not output:
+                continue
+
+            for line in output.strip().splitlines():
+                parts = line.split(" ", 1)
+                c = parts[0]
+                title = parts[1] if len(parts) > 1 else ""
+
+                if c == oid:
+                    continue
+                if c not in cherry_pick_oids:
+                    entry = missing_commits.setdefault(c, {"title": title, "files": []})
+                    if not entry["title"]:
+                        entry["title"] = title
+                    entry["files"].append(filepath)
+
+        # Print deduplicated warnings
+        if missing_commits:
+            conflicting_prs_count += 1
+            for missing_oid, entry in missing_commits.items():
+                files_str = ", ".join(entry["files"])
+                print(
+                    f"WARNING: PR #{number} ({oid}) modifies files that were also changed by commit {missing_oid} ({entry['title']}), "
+                    f"which is not in the cherry-pick list. This may indicate missing related changes. Affected files: {files_str}"
+                )
+
+    if conflicting_prs_count == 0:
+        print("No potential missing dependencies found.")
+    else:
+        print(f"\nDone. Found potential missing dependencies for {conflicting_prs_count} PRs.")
+
+
+def main():
+    """Parse arguments, write the cherry-pick script and PR description, then run the dependency check."""
+    parser = argparse.ArgumentParser(description="Generate cherry-pick script from PRs with a specific label.")
+    parser.add_argument("--label", required=True, help="Label to filter PRs")
+    parser.add_argument(
+        "--output", required=True, help="Output script file path (.sh for bash, .cmd for Windows batch)"
+    )
+    parser.add_argument("--repo", default="microsoft/onnxruntime", help="Repository (default: microsoft/onnxruntime)")
+    parser.add_argument(
+        "--branch", default="HEAD", help="Target branch to compare against for dependency checks (default: HEAD)"
+    )
+    parser.add_argument("--limit", type=int, default=200, help="Maximum number of PRs to fetch (default: 200)")
+    parser.add_argument(
+        "--md-output",
+        help="Output markdown file path for the PR description (default: next to --output)",
+    )
+    args = parser.parse_args()
+
+    # Preflight Check
+    if not check_preflight():
+        sys.exit(1)  # Non-zero status so calling scripts/CI can detect the failed preflight
+
+    # 1. Fetch Merged PRs
+    prs = get_merged_prs(args.repo, args.label, args.limit)
+
+    if not prs:
+        print(f"No PRs found with label '{args.label}'.")
+        return
+
+    # Sort by mergedAt (ISO 8601 strings sort correctly in chronological order)
+    prs.sort(key=lambda x: x["mergedAt"])
+
+    # 1.5. Check which PRs are already in the target branch
+    existing_prs = get_existing_pr_numbers(args.branch, repo=args.repo)
+    if existing_prs:
+        print(f"Found {len(existing_prs)} PRs already in branch '{args.branch}'.")
+
+    cherry_pick_prs = []
+    skipped_count = 0
+    for pr in prs:
+        number = pr["number"]
+        safe_title = sanitize_title(pr["title"])
+
+        if not pr.get("mergeCommit"):
+            print(f"Warning: PR #{number} has no merge commit OID. Skipping.", file=sys.stderr)
+            continue
+
+        if number in existing_prs:
+            print(f"Skipping PR #{number} (already in branch '{args.branch}'): {safe_title}")
+            skipped_count += 1
+            continue
+
+        cherry_pick_prs.append(pr)
+
+    # Determine output format based on file extension
+    is_shell = args.output.endswith(".sh")
+
+    # 2. Write Output Script
+    commit_count = len(cherry_pick_prs)
+    with open(args.output, "w", encoding="utf-8") as f:
+        if is_shell:
+            f.write("#!/bin/bash\n")
+            f.write(f"# Cherry-pick {args.label} commits\n")
+            f.write("# Sorted by merge time (oldest first)\n")
+            f.write("set -e\n\n")
+        else:
+            f.write("@echo off\n")
+            f.write(f"rem Cherry-pick {args.label} commits\n")
+            f.write("rem Sorted by merge time (oldest first)\n\n")
+
+        for pr in cherry_pick_prs:
+            number = pr["number"]
+            safe_title = sanitize_title(pr["title"])
+
+            oid = pr["mergeCommit"]["oid"]
+            comment = "#" if is_shell else "rem"
+            f.write(f"{comment} PR {number}: {safe_title}\n")
+            f.write(f"git cherry-pick {oid}\n\n")
+
+    print(f"Generated {args.output} with {commit_count} commits ({skipped_count} skipped, already in branch).")
+
+    # 3. Write PR Description Markdown (table format)
+    output_dir = os.path.dirname(args.output)
+    md_output = args.md_output or os.path.join(output_dir, "cherry_pick_pr_description.md")
+    with open(md_output, "w", encoding="utf-8") as f:
+        f.write("This cherry-picks the following commits for the release:\n\n")
+        f.write("| Commit ID | PR Number | Commit Title |\n")
+        f.write("|-----------|-----------|-------------|\n")
+        for pr in cherry_pick_prs:
+            number = pr["number"]
+            title = escape_markdown_table_cell(pr["title"])
+            oid = pr["mergeCommit"]["oid"]
+            short_oid = oid[:10]
+            f.write(f"| {short_oid} | #{number} | {title} |\n")
+
+    print(f"Generated {md_output} with {commit_count} commits.")
+
+    # 4. Dependency Check
+    check_missing_dependencies(cherry_pick_prs, args.branch)
+
+
+if __name__ == "__main__":
+    main()
diff --git a/tools/python/cherry_pick_utils.py b/tools/python/cherry_pick_utils.py
new file mode 100644
index 0000000000000..7460761593a8c
--- /dev/null
+++ b/tools/python/cherry_pick_utils.py
@@ -0,0 +1,113 @@
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License.
+
+import re
+import subprocess
+import sys
+
+
+def run_command(command_list, cwd=None, silent=False):
+    """Run a command using a list of arguments for security (no shell=True)."""
+    try:
+        result = subprocess.run(command_list, check=False, capture_output=True, text=True, cwd=cwd, encoding="utf-8")
+        if result.returncode != 0:
+            if not silent:
+                log_str = " ".join(command_list)
+                print(f"Error running command: {log_str}", file=sys.stderr)
+                if result.stderr:
+                    print(f"Stderr: {result.stderr.strip()}", file=sys.stderr)
+            return None
+        return result.stdout
+    except FileNotFoundError:
+        if not silent:
+            cmd = command_list[0]
+            print(f"Error: '{cmd}' command not found.", file=sys.stderr)
+            if cmd == "gh":
+                print(
+                    "Please install GitHub CLI (https://cli.github.com/) and ensure 'gh' is available on your PATH.",
+                    file=sys.stderr,
+                )
+        return None
+    except Exception as e:
+        if not silent:
+            print(f"Exception running command {' '.join(command_list)}: {e}", file=sys.stderr)
+        return None
+
+
+def check_preflight():
+    """Verify gh CLI and git repository early."""
+    # Check git
+    git_check = run_command(["git", "rev-parse", "--is-inside-work-tree"], silent=True)
+    if not git_check:
+        print("Error: This script must be run inside a git repository.", file=sys.stderr)
+        return False
+
+    # Check gh
+    gh_check = run_command(["gh", "--version"], silent=True)
+    if not gh_check:
+        print("Error: GitHub CLI (gh) not found or not in PATH.", file=sys.stderr)
+        print(
+            "Please install GitHub CLI (https://cli.github.com/) and ensure 'gh' is available on your PATH.",
+            file=sys.stderr,
+        )
+        return False
+
+    # gh auth status outputs to stderr, so run_command returns empty stdout even on success.
+    # Use subprocess directly to check the return code.
+    try:
+        auth_result = subprocess.run(["gh", "auth", "status"], capture_output=True, text=True, check=False)
+        if auth_result.returncode != 0:
+            print("Error: GitHub CLI not authenticated. Please run 'gh auth login'.", file=sys.stderr)
+            return False
+    except FileNotFoundError:
+        print("Error: GitHub CLI (gh) not found.", file=sys.stderr)
+        return False
+
+    return True
+
+
+def get_pr_number_from_subject(subject):
+    """Extract PR number from a commit subject like 'Some title (#12345)'."""
+    match = re.search(r"\(#(\d+)\)$", subject.strip())
+    if match:
+        return match.group(1)
+    return None
+
+
+def extract_pr_numbers(text, strict=False):
+    """Extract pull-request numbers referenced in free-form text.
+
+    Args:
+        text: Text to scan, e.g. a PR title/body or a commit headline.
+        strict: If True, only accept "(#123)", onnxruntime pull URLs, markdown table
+            cells, and standalone "#123"; if False, accept any "#123" or "/pull/123".
+
+    Returns:
+        De-duplicated PR numbers as ints (order unspecified); [] for empty text.
+    """
+    if not text:
+        return []
+
+    if strict:
+        # Strict mode: Only look for (#123) with closing paren, full onnxruntime URLs,
+        # or PR numbers in markdown table cells (| #123 |), or standalone #123 with clear boundaries.
+        # This avoids noise from version numbers or external repo PRs
+        # And it avoids matching truncated headlines like (#25... as PR #25
+        patterns = [
+            r"\(#(\d+)\)",  # (#123)
+            r"microsoft/onnxruntime/pull/(\d+)",
+            # Delimiters are matched with lookarounds so they are not consumed; otherwise the
+            # second reference in "#1 #2" (or "| #1 | #2 |") would be skipped by re.findall.
+            r"(?<![^\s-])#(\d+)(?!\S)",  # #123 at start, or preceded by space/dash, and followed by space or end
+            r"\|\s*#(\d+)\s*(?=\|)",  # | #123 | (markdown table cell)
+        ]
+
+        results = []
+        for p in patterns:
+            results.extend(re.findall(p, text))
+        return [int(x) for x in set(results)]
+
+    # Matches patterns like #123 or https://github.com/microsoft/onnxruntime/pull/123
+    # Also handles ( #123) or similar in titles
+    prs = re.findall(r"(?:#|/pull/)(\d+)", text)
+    return [int(x) for x in set(prs)]
diff --git a/tools/python/compile_contributors.py b/tools/python/compile_contributors.py
index 494b0f91c5381..bb02c2807d08c 100644
--- a/tools/python/compile_contributors.py
+++ b/tools/python/compile_contributors.py
@@ -30,7 +30,15 @@
 import json
 import os
 import re
-import subprocess
+
+from cherry_pick_utils import (
+    check_preflight,
+    extract_pr_numbers,
+    run_command,
+)
+from cherry_pick_utils import (
+    get_pr_number_from_subject as get_pr_number,
+)
 
 
 def log_event(message, log_file=None):
@@ -42,98 +50,9 @@ def log_event(message, log_file=None):
         log_file.write(full_message + "\n")
 
 
-def run_command(command_list, cwd=".", silent=False):
-    """Run a command using a list of arguments for security (no shell=True)."""
-    result = subprocess.run(command_list, check=False, capture_output=True, text=True, cwd=cwd, encoding="utf-8")
-    if result.returncode != 0:
-        if not silent:
-            log_str = " ".join(command_list)
-            print(f"Error running command: {log_str}")
-            if result.stderr:
-                print(f"Stderr: {result.stderr.strip()}")
-        return None
-    return result.stdout
-
-
-def check_preflight():
-    """Verify gh CLI and git repository early."""
-    # Check git
-    git_check = run_command(["git", "rev-parse", "--is-inside-work-tree"], silent=True)
-    if not git_check:
-        print("Error: This script must be run inside a git repository.")
-        return False
-
-    # Check gh
-    gh_check = run_command(["gh", "--version"], silent=True)
-    if not gh_check:
-        print("Error: GitHub CLI (gh) not found or not in PATH.")
-        return False
-
-    gh_auth = run_command(["gh", "auth", "status"], silent=True)
-    if not gh_auth:
-        print("Error: GitHub CLI not authenticated. Please run 'gh auth login'.")
-        return False
-
-    return True
-
-
-# Constants
-PR_CACHE = {}  # Cache for PR details to speed up multiple rounds referencing same PRs
 NAME_TO_LOGIN = {}  # Map full names to GitHub logins for consolidation
 VERIFIED_LOGINS = set()  # Track IDs known to be valid GitHub logins (vs free-form names)
-
-# Bots to exclude from contributor lists
-BOT_NAMES = {
-    "Copilot",
-    "dependabot[bot]",
-    "app/dependabot",
-    "github-actions[bot]",
-    "app/copilot-swe-agent",
-    "CI Bot",
-    "github-advanced-security[bot]",
-    "GitHub Actions",
-    "dependabot",
-    "github-actions",
-    "Gemini",
-    "CI",
-}
-
-
-def is_bot(name):
-    if not name:
-        return True
-    name_clean = name.strip().lstrip("@")
-    # Known bots and patterns
-    if name_clean in BOT_NAMES:
-        return True
-    if "[bot]" in name_clean.lower():
-        return True
-    if name_clean.lower().startswith("app/"):
-        return True
-    return False
-
-
-def is_invalid(name):
-    if not name:
-        return True
-    # If it's a bot, it's considered a valid identity for the CSV
-    if is_bot(name):
-        return False
-
-    name_clean = name.strip().lstrip("@")
-    # Paths, brackets, and code extensions
-    if "/" in name_clean or "\\" in name_clean or "[" in name_clean or "]" in name_clean:
-        return True
-    if any(name_clean.lower().endswith(ext) for ext in [".cmake", ".py", ".h", ".cc", ".cpp", ".txt", ".md"]):
-        return True
-    return False
-
-
-def get_pr_number(subject):
-    match = re.search(r"\(#(\d+)\)$", subject.strip())
-    if match:
-        return match.group(1)
-    return None
+PR_CACHE = {}  # Cache for PR details to speed up multiple rounds referencing same PRs
 
 
 def get_pr_details(pr_number):
@@ -209,27 +128,51 @@ def extract_authors_from_commit(commit_id):
     return authors
 
 
-def extract_pr_numbers(text, strict=False):
-    if not text:
-        return []
+# Bots to exclude from contributor lists
+BOT_NAMES = {
+    "Copilot",
+    "dependabot[bot]",
+    "app/dependabot",
+    "github-actions[bot]",
+    "app/copilot-swe-agent",
+    "CI Bot",
+    "github-advanced-security[bot]",
+    "GitHub Actions",
+    "dependabot",
+    "github-actions",
+    "Gemini",
+    "CI",
+}
 
-    if strict:
-        # Strict mode: Only look for (#123) with closing paren or full onnxruntime URLs
-        # This avoids noise from version numbers or external repo PRs
-        # And it avoids matching truncated headlines like (#25... as PR #25
-        patterns = [
-            r"\(#(\d+)\)",  # (#123)
-            r"microsoft/onnxruntime/pull/(\d+)",
-        ]
-        results = []
-        for p in patterns:
-            results.extend(re.findall(p, text))
-        return [int(x) for x in set(results)]
 
-    # Matches patterns like #123 or https://github.com/microsoft/onnxruntime/pull/123
-    # Also handles ( #123) or similar in titles
-    prs = re.findall(r"(?:#|/pull/)(\d+)", text)
-    return [int(x) for x in set(prs)]
+def is_bot(name):
+    if not name:
+        return True
+    name_clean = name.strip().lstrip("@")
+    # Known bots and patterns
+    if name_clean in BOT_NAMES:
+        return True
+    if "[bot]" in name_clean.lower():
+        return True
+    if name_clean.lower().startswith("app/"):
+        return True
+    return False
+
+
+def is_invalid(name):
+    if not name:
+        return True
+    # If it's a bot, it's considered a valid identity for the CSV
+    if is_bot(name):
+        return False
+
+    name_clean = name.strip().lstrip("@")
+    # Paths, brackets, and code extensions
+    if "/" in name_clean or "\\" in name_clean or "[" in name_clean or "]" in name_clean:
+        return True
+    if any(name_clean.lower().endswith(ext) for ext in [".cmake", ".py", ".h", ".cc", ".cpp", ".txt", ".md"]):
+        return True
+    return False
 
 
 def get_prs_from_log(log_output, prs_base=None, log_file=None, scan_depth=100):
@@ -280,7 +223,7 @@ def get_prs_from_log(log_output, prs_base=None, log_file=None, scan_depth=100):
                 # Reuse commits already fetched in get_pr_details to avoid an extra gh CLI call
                 for commit in details.get("commits", []):
                     all_extracted_nums.extend(extract_pr_numbers(commit.get("messageHeadline", ""), strict=True))
-                    all_extracted_nums.extend(extract_pr_numbers(commit.get("messageBody", ""), strict=True))
+                    # DO NOT scan messageBody for expansion to avoid historical context PRs
 
                 # Filter and Normalize
                 current_pr_int = int(pr_num_str)