Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 74 additions & 0 deletions .github/workflows/update-coverage-include.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
name: Update Coverage Include List

on:
push:
branches:
- amd-integration
paths:
- 'flashinfer/**/*.py'

permissions:
contents: write
pull-requests: write

jobs:
update-coverage:
runs-on: ubuntu-latest
steps:
- name: Checkout amd-integration branch
uses: actions/checkout@v4
with:
ref: amd-integration
fetch-depth: 0

- name: Fetch upstream main
run: |
git remote add upstream https://github.com/flashinfer-ai/flashinfer.git || true
git fetch upstream main

- name: Setup Python
uses: actions/setup-python@v5
with:
python-version: '3.12'

- name: Run update script
id: update
run: |
python3 scripts/update_coverage_include.py > update_log.txt 2>&1 || (cat update_log.txt && exit 1)
if git diff --quiet pyproject.toml; then
echo "changed=false" >> $GITHUB_OUTPUT
cat update_log.txt
else
echo "changed=true" >> $GITHUB_OUTPUT
cat update_log.txt
fi

- name: Create Pull Request
if: steps.update.outputs.changed == 'true'
uses: peter-evans/create-pull-request@v6
with:
token: ${{ secrets.GITHUB_TOKEN }}
commit-message: 'chore: Update coverage include list for AMD-modified files'
branch: update-coverage-include-${{ github.run_number }}
delete-branch: true
title: 'chore: Update coverage include list'
body: |
## Coverage Include List Update

This PR was automatically generated after changes were merged to `amd-integration`.

The coverage include list in `pyproject.toml` has been updated to reflect the current set of AMD/HIP modified files compared to upstream.

### Changes
- Updated `[tool.coverage.run].include` list based on `git merge-base upstream/main HEAD`

### Modified Files Count
Check the commit for the updated list of files.

---
*Automated by GitHub Actions*
labels: |
automated
coverage
chore
assignees: ${{ github.actor }}
2 changes: 2 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -229,3 +229,5 @@ pytest -v
```

The default test configuration is specified in [pyproject.toml](pyproject.toml) under the `testpaths` setting.

test test test
2 changes: 2 additions & 0 deletions flashinfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,8 @@ def _check_torch_rocm_compatibility():
3. PyTorch ROCm version matches system ROCm version (if detectable)

Provides helpful error messages to guide users to correct installation.

# FIXME: Test test
"""

# Check for torch package
Expand Down
6 changes: 5 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -220,8 +220,12 @@ python_classes = ["Test*"]
python_functions = ["test_*"]

[tool.coverage.run]
source = ["flashinfer"]
# Note: 'source' is commented out because 'include' would be ignored if source is set
# source = ["flashinfer"]
branch = true
# AMD/HIP modified files - auto-updated by GitHub Action
include = [
]

[tool.coverage.report]
show_missing = true
Expand Down
163 changes: 163 additions & 0 deletions scripts/update_coverage_include.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,163 @@
#!/usr/bin/env python3
"""
Update the coverage include list in pyproject.toml based on AMD/HIP modified files.

This script compares the current branch against the merge base with upstream/main
to find modified Python files in the flashinfer/ module and updates the coverage include list.
"""

import re
import subprocess
import sys
from pathlib import Path


def get_merge_base(upstream_ref="upstream/main"):
"""Get the merge base (fork point) between current branch and upstream."""
try:
result = subprocess.run(
["git", "merge-base", upstream_ref, "HEAD"],
capture_output=True,
text=True,
check=True,
)
merge_base = result.stdout.strip()
if not merge_base:
raise ValueError("Empty merge base")
return merge_base
except (subprocess.CalledProcessError, ValueError) as e:
print(
f"Error: Failed to find merge base with {upstream_ref}: {e}",
file=sys.stderr,
)
sys.exit(1)


def get_modified_files(upstream_ref="upstream/main"):
"""Get list of modified Python files in flashinfer/ module."""
# Find the fork point
merge_base = get_merge_base(upstream_ref)
print(f" Merge base: {merge_base[:8]}")

try:
result = subprocess.run(
["git", "diff", "--name-only", "--diff-filter=AM", merge_base, "HEAD"],
capture_output=True,
text=True,
check=True,
)
except subprocess.CalledProcessError as e:
print(f"Error: Failed to get git diff: {e}", file=sys.stderr)
sys.exit(1)

modified_files = [
line.strip()
for line in result.stdout.strip().split("\n")
if line.strip() and line.startswith("flashinfer/") and line.endswith(".py")
]

return sorted(modified_files)


def update_pyproject_toml(modified_files, dry_run=False):
"""Update the include list in pyproject.toml."""
pyproject_path = Path("pyproject.toml")

if not pyproject_path.exists():
print("Error: pyproject.toml not found", file=sys.stderr)
sys.exit(1)

# Read current content
with open(pyproject_path, "r") as f:
content = f.read()

# Build the new include list
include_lines = ["include = ["]
for file in modified_files:
include_lines.append(f' "{file}",')
include_lines.append("]")
new_include = "\n".join(include_lines)

# Pattern to match the include section
# Matches from "include = [" to the closing "]"
pattern = r"(# AMD/HIP modified files.*\n)include = \[[^\]]*\]"

# Check if pattern exists
if not re.search(pattern, content, flags=re.DOTALL):
print(
"Error: Could not find coverage include section in pyproject.toml",
file=sys.stderr,
)
print(
"Make sure the file has the marker comment: '# AMD/HIP modified files'",
file=sys.stderr,
)
sys.exit(1)

# Replace the include section
new_content = re.sub(pattern, r"\1" + new_include, content, flags=re.DOTALL)

# Check if content changed
if new_content == content:
print("✓ Coverage include list is already up to date")
return False

if dry_run:
print("Would update pyproject.toml with the following files:")
for file in modified_files:
print(f" - {file}")
return True

# Write updated content
with open(pyproject_path, "w") as f:
f.write(new_content)

print("✓ Updated coverage include list in pyproject.toml")
print(f" Modified files: {len(modified_files)}")
for file in modified_files:
print(f" - {file}")

return True


def main():
import argparse

parser = argparse.ArgumentParser(
description="Update coverage include list based on AMD/HIP modified files"
)
parser.add_argument(
"--upstream",
default="upstream/main",
help="Upstream reference to compare against (default: upstream/main)",
)
parser.add_argument(
"--dry-run",
action="store_true",
help="Show what would be changed without modifying files",
)

args = parser.parse_args()

print(f"Comparing against {args.upstream}...")
modified_files = get_modified_files(args.upstream)

if not modified_files:
print("No modified Python files found in flashinfer/ module")
sys.exit(0)

print(f"Found {len(modified_files)} modified Python files in flashinfer/")

changed = update_pyproject_toml(modified_files, dry_run=args.dry_run)

if changed and not args.dry_run:
print("\nNext steps:")
print(" 1. Review changes: git diff pyproject.toml")
print(" 2. Test coverage: pytest --cov --cov-report=term-missing")
print(
" 3. Commit: git add pyproject.toml && git commit -m 'Update coverage include list'"
)


if __name__ == "__main__":
main()