diff --git a/.github/workflows/update-coverage-include.yml b/.github/workflows/update-coverage-include.yml new file mode 100644 index 0000000000..d22c97b70a --- /dev/null +++ b/.github/workflows/update-coverage-include.yml @@ -0,0 +1,73 @@ +name: Update Coverage Include List + +on: + push: + branches: + - amd-integration + paths: + - 'flashinfer/**/*.py' + workflow_dispatch: + +permissions: + contents: write + pull-requests: write + +jobs: + update-coverage: + runs-on: ubuntu-latest + steps: + - name: Checkout amd-integration branch + uses: actions/checkout@v4 + with: + ref: amd-integration + fetch-depth: 0 + + - name: Fetch upstream main + run: | + git remote add upstream https://github.com/flashinfer-ai/flashinfer.git || true + git fetch upstream main + + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: '3.12' + + - name: Run update script + id: update + run: | + python3 scripts/update_coverage_include.py + if git diff --quiet pyproject.toml; then + echo "changed=false" >> $GITHUB_OUTPUT + else + echo "changed=true" >> $GITHUB_OUTPUT + fi + + - name: Create Pull Request + if: steps.update.outputs.changed == 'true' + uses: peter-evans/create-pull-request@v6 + with: + token: ${{ secrets.GITHUB_TOKEN }} + commit-message: 'chore: Update coverage include list for AMD-modified files' + branch: update-coverage-include-${{ github.run_number }} + delete-branch: true + title: 'chore: Update coverage include list' + body: | + ## Coverage Include List Update + + This PR was automatically generated after changes were merged to `amd-integration`. + + The coverage include list in `pyproject.toml` has been updated to reflect the current set of AMD/HIP modified files compared to upstream. + + ### Changes + - Updated `[tool.coverage.run].include` list based on `git merge-base upstream/main HEAD` + + ### Modified Files Count + Check the commit for the updated list of files. + + --- + *Automated by GitHub Actions* + labels: | + automated + coverage + chore + assignees: ${{ github.actor }} diff --git a/pyproject.toml b/pyproject.toml index 0e7488717c..018785069a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -220,8 +220,12 @@ python_classes = ["Test*"] python_functions = ["test_*"] [tool.coverage.run] -source = ["flashinfer"] +# Note: 'source' is commented out because 'include' would be ignored if source is set +# source = ["flashinfer"] branch = true +# AMD/HIP modified files - auto-updated by GitHub Action +include = [ +] [tool.coverage.report] show_missing = true diff --git a/scripts/update_coverage_include.py b/scripts/update_coverage_include.py new file mode 100755 index 0000000000..8abeda77aa --- /dev/null +++ b/scripts/update_coverage_include.py @@ -0,0 +1,163 @@ +#!/usr/bin/env python3 +""" +Update the coverage include list in pyproject.toml based on AMD/HIP modified files. + +This script compares the current branch against the merge base with upstream/main +to find modified Python files in the flashinfer/ module and updates the coverage include list. +""" + +import re +import subprocess +import sys +from pathlib import Path + + +def get_merge_base(upstream_ref="upstream/main"): + """Get the merge base (fork point) between current branch and upstream.""" + try: + result = subprocess.run( + ["git", "merge-base", upstream_ref, "HEAD"], + capture_output=True, + text=True, + check=True, + ) + merge_base = result.stdout.strip() + if not merge_base: + raise ValueError("Empty merge base") + return merge_base + except (subprocess.CalledProcessError, ValueError) as e: + print( + f"Error: Failed to find merge base with {upstream_ref}: {e}", + file=sys.stderr, + ) + sys.exit(1) + + +def get_modified_files(upstream_ref="upstream/main"): + """Get list of modified Python files in flashinfer/ module.""" + # Find the fork point + merge_base = get_merge_base(upstream_ref) + print(f" Merge base: {merge_base[:8]}") + + try: + result = subprocess.run( + ["git", "diff", "--name-only", "--diff-filter=AM", merge_base, "HEAD"], + capture_output=True, + text=True, + check=True, + ) + except subprocess.CalledProcessError as e: + print(f"Error: Failed to get git diff: {e}", file=sys.stderr) + sys.exit(1) + + modified_files = [ + line.strip() + for line in result.stdout.strip().split("\n") + if line.strip() and line.startswith("flashinfer/") and line.endswith(".py") + ] + + return sorted(modified_files) + + +def update_pyproject_toml(modified_files, dry_run=False): + """Update the include list in pyproject.toml.""" + pyproject_path = Path("pyproject.toml") + + if not pyproject_path.exists(): + print("Error: pyproject.toml not found", file=sys.stderr) + sys.exit(1) + + # Read current content + with open(pyproject_path, "r") as f: + content = f.read() + + # Build the new include list + include_lines = ["include = ["] + for file in modified_files: + include_lines.append(f' "{file}",') + include_lines.append("]") + new_include = "\n".join(include_lines) + + # Pattern to match the include section + # Matches from "include = [" to the closing "]" + pattern = r"(# AMD/HIP modified files.*\n)include = \[[^\]]*\]" + + # Check if pattern exists + if not re.search(pattern, content, flags=re.DOTALL): + print( + "Error: Could not find coverage include section in pyproject.toml", + file=sys.stderr, + ) + print( + "Make sure the file has the marker comment: '# AMD/HIP modified files'", + file=sys.stderr, + ) + sys.exit(1) + + # Replace the include section + new_content = re.sub(pattern, r"\1" + new_include, content, flags=re.DOTALL) + + # Check if content changed + if new_content == content: + print("✓ Coverage include list is already up to date") + return False + + if dry_run: + print("Would update pyproject.toml with the following files:") + for file in modified_files: + print(f" - {file}") + return True + + # Write updated content + with open(pyproject_path, "w") as f: + f.write(new_content) + + print("✓ Updated coverage include list in pyproject.toml") + print(f" Modified files: {len(modified_files)}") + for file in modified_files: + print(f" - {file}") + + return True + + +def main(): + import argparse + + parser = argparse.ArgumentParser( + description="Update coverage include list based on AMD/HIP modified files" + ) + parser.add_argument( + "--upstream", + default="upstream/main", + help="Upstream reference to compare against (default: upstream/main)", + ) + parser.add_argument( + "--dry-run", + action="store_true", + help="Show what would be changed without modifying files", + ) + + args = parser.parse_args() + + print(f"Comparing against {args.upstream}...") + modified_files = get_modified_files(args.upstream) + + if not modified_files: + print("No modified Python files found in flashinfer/ module") + sys.exit(0) + + print(f"Found {len(modified_files)} modified Python files in flashinfer/") + + changed = update_pyproject_toml(modified_files, dry_run=args.dry_run) + + if changed and not args.dry_run: + print("\nNext steps:") + print(" 1. Review changes: git diff pyproject.toml") + print(" 2. Test coverage: pytest --cov --cov-report=term-missing") + print( + " 3. Commit: git add pyproject.toml && git commit -m 'Update coverage include list'" + ) + + +if __name__ == "__main__": + main()