Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 5 additions & 14 deletions .github/workflows/agent-currency-fix.yml
Original file line number Diff line number Diff line change
Expand Up @@ -139,18 +139,6 @@ jobs:
Agent run: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}"
fi

- name: Download failed run logs
if: steps.failures.outputs.has_failures == 'true' && steps.retry.outputs.max_reached != 'true'
env:
GH_TOKEN: ${{ steps.app-token.outputs.token }}
FAILED_RUNS: ${{ steps.failures.outputs.failed_runs }}
run: |
mkdir -p /tmp/ci-logs
for RUN in $FAILED_RUNS; do
gh api "/repos/${{ github.repository }}/actions/runs/${RUN}/logs" > "/tmp/ci-logs-${RUN}.zip" || true
[ -s "/tmp/ci-logs-${RUN}.zip" ] && unzip -o "/tmp/ci-logs-${RUN}.zip" -d "/tmp/ci-logs/run-${RUN}/" || true
done

- name: Determine framework
if: steps.failures.outputs.has_failures == 'true' && steps.retry.outputs.max_reached != 'true'
run: |
Expand All @@ -162,12 +150,15 @@ jobs:
id: fix
env:
AWS_REGION: us-west-2
GH_TOKEN: ${{ steps.app-token.outputs.token }}
run: |
python3 -m pip install boto3 -q
python3 /tmp/agent-fix.py \
--logs-dir /tmp/ci-logs/ \
--framework "$FRAMEWORK" \
--branch "$HEAD_BRANCH"
--branch "$HEAD_BRANCH" \
--run-ids "${{ steps.failures.outputs.failed_runs }}" \
--token "$GH_TOKEN" \
--repo "${{ github.repository }}"

- name: Commit and push
if: steps.failures.outputs.has_failures == 'true' && steps.retry.outputs.max_reached != 'true' && steps.fix.outcome == 'success'
Expand Down
198 changes: 155 additions & 43 deletions scripts/autocurrency/agent-fix.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
- ONLY fix the specific failure shown in the logs
- Do NOT delete or skip tests
- Do NOT modify files unrelated to the failure
- ONLY edit files that are provided in the context below. If a file is not shown, do not edit it.
- For CVE scan failures: pin a safe version in Dockerfile, or add to allowlist if vendored/unpatchable
- For "file not found" errors: find the new path in the upstream repo
- For build errors: check if upstream base image changed something
Expand All @@ -43,30 +44,123 @@
If the failure is TRANSIENT (capacity, timeout, runner crash), respond with exactly:
TRANSIENT: <brief reason>

Otherwise, respond with search/replace blocks:
Otherwise, respond with search/replace blocks. Use this EXACT format:

<filepath>
path/to/file.ext
<<<<<<< SEARCH
<exact text to find in the file>
exact text to find in the file
=======
<replacement text>
replacement text
>>>>>>> REPLACE

IMPORTANT: Write the file path as plain text (e.g., docker/vllm/Dockerfile). Do NOT wrap it in angle brackets, backticks, or any other formatting.

Include 1-2 surrounding lines in SEARCH for unique anchoring.
For JSON arrays (allowlists), SEARCH the last few lines and REPLACE with those lines plus the new entry.

End with: DESCRIPTION: <one-line commit message>"""
End with: DESCRIPTION: one-line commit message"""


def parse_args(argv=None):
    """Parse the command-line options of the currency fix agent.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default, and what existing callers get) argparse reads
            ``sys.argv[1:]``. Supplying a list makes the parser usable from
            other code without touching the process arguments.

    Returns:
        An ``argparse.Namespace`` with the attributes ``logs_dir``,
        ``framework``, ``branch``, ``run_ids``, ``token`` and ``repo``.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--logs-dir", required=True, help="Directory containing downloaded CI logs")
    p.add_argument("--framework", required=True, help="Framework name")
    p.add_argument("--branch", required=True, help="Branch name")
    p.add_argument("--run-ids", default="", help="Space-separated failed run IDs")
    # The environment is read at call time, so GH_TOKEN only has to be set
    # before parse_args() runs, not before the module is imported.
    p.add_argument("--token", default=os.environ.get("GH_TOKEN", ""), help="GitHub token")
    p.add_argument("--repo", default="aws/deep-learning-containers", help="Repository as owner/name")
    return p.parse_args(argv)


def extract_error_lines(logs_dir: str) -> str:
def _github_get(url: str, token: str, timeout: float = 60) -> bytes:
    """Fetch *url* with GitHub token authentication and return the raw body.

    The response is closed before returning, and a socket timeout is applied
    so a stalled connection cannot block the caller indefinitely.
    """
    import urllib.request

    req = urllib.request.Request(
        url,
        headers={
            "Authorization": f"token {token}",
            "Accept": "application/vnd.github+json",
        },
    )
    with urllib.request.urlopen(req, timeout=timeout) as resp:
        return resp.read()


def _match_tracked_job(job_name: str):
    """Return the tracked job key contained in *job_name*, or ``None``.

    The comparison ignores case, hyphens and spaces in the job name.
    """
    tracked_jobs = (
        "build-image",
        "sanity-test",
        "security-test",
        "telemetry-test",
        "upstream-tests",
        "sagemaker-test",
    )
    squashed = job_name.lower().replace("-", "").replace(" ", "")
    for key in tracked_jobs:
        if key.replace("-", "") in squashed:
            return key
    return None


def _job_log_lines(log_zip, job_name: str) -> list:
    """Return the report lines for *job_name* taken from an open run-log zip.

    The first archive member whose name contains the job name (with
    ``" / "`` replaced by ``" _ "``) is used; if none matches, a single
    explanatory line is returned instead.
    """
    target = job_name.replace(" / ", " _ ")
    for name in log_zip.namelist():
        if target in name:
            log_lines = log_zip.read(name).decode(errors="replace").splitlines()
            return [
                f" Log ({name}, {len(log_lines)} lines):",
                *(f" {line}" for line in log_lines),
            ]
    return [f" No matching log file for '{target}' in zip"]


def extract_failure_info(run_ids: str, token: str, repo: str) -> tuple:
    """Use GitHub API to get structured failure info.

    For every run ID the jobs of the run are listed; each job whose
    conclusion is ``"failure"`` and whose name matches a tracked job key is
    reported together with its failed steps and its log text from the run's
    log archive.

    Args:
        run_ids: Whitespace-separated workflow run IDs. Empty or blank input
            yields the "No failure info extracted." placeholder.
        token: GitHub token sent in the ``Authorization`` header.
        repo: Repository in ``owner/name`` form.

    Returns:
        ``(error_text, failed_job_names)``: the report as one newline-joined
        string, and the matched tracked job keys in first-seen order without
        duplicates.

    Network and parsing errors are not raised; they are recorded as lines in
    the report so the caller always gets a result (best effort).
    """
    import io
    import zipfile

    print("Using GitHub API for structured failure extraction")

    api_root = f"https://api.github.com/repos/{repo}/actions/runs"
    results = []
    failed_job_names = []
    # split() with no argument already drops leading/trailing whitespace and
    # never yields empty strings.
    for run_id in run_ids.split():
        try:
            data = json.loads(_github_get(f"{api_root}/{run_id}/jobs?per_page=100", token))
        except Exception as e:  # best effort: report and move on to the next run
            results.append(f"Failed to fetch jobs for run {run_id}: {e}")
            continue

        # The log archive belongs to the run, not to a job: download it lazily
        # on the first failed tracked job and reuse it (or the download error)
        # for every further job of the same run.
        log_zip = None
        log_error = None
        for job in data.get("jobs", []):
            if job.get("conclusion") != "failure":
                continue

            # Only process jobs that match our tracked job names
            matched_key = _match_tracked_job(job["name"])
            if not matched_key:
                continue

            failed_steps = [
                s["name"] for s in job.get("steps", []) if s.get("conclusion") == "failure"
            ]
            results.append(f"FAILED JOB: {job['name']}")
            # Several jobs (e.g. matrix variants) can map to the same key;
            # record each key once.
            if matched_key not in failed_job_names:
                failed_job_names.append(matched_key)
            results.append(f" Failed steps: {', '.join(failed_steps)}")

            if log_zip is None and log_error is None:
                try:
                    log_zip = zipfile.ZipFile(
                        io.BytesIO(_github_get(f"{api_root}/{run_id}/logs", token))
                    )
                except Exception as e:  # best effort: keep the job summary without logs
                    log_error = e

            if log_zip is None:
                results.append(f" Failed to download logs: {log_error}")
            else:
                try:
                    results.extend(_job_log_lines(log_zip, job["name"]))
                except Exception as e:  # e.g. a corrupt archive member
                    results.append(f" Failed to download logs: {e}")

            results.append("")

        if log_zip is not None:
            log_zip.close()

    return "\n".join(results) or "No failure info extracted.", failed_job_names


def _extract_via_grep(logs_dir: str) -> str:
"""Fallback: grep log files for error keywords."""
logs_path = Path(logs_dir)
if not logs_path.exists():
return "No logs available."
Expand All @@ -82,7 +176,7 @@ def extract_error_lines(logs_dir: str) -> str:
for i, line in enumerate(lines):
if any(kw in line.lower() for kw in keywords):
start, end = max(0, i - 2), min(len(lines), i + 3)
error_lines.append(f"--- {log_file.name}:{i+1} ---")
error_lines.append(f"--- {log_file.name}:{i + 1} ---")
error_lines.extend(lines[start:end])
error_lines.append("")
if len(error_lines) > MAX_LOG_LINES:
Expand All @@ -107,7 +201,14 @@ def detect_failed_jobs(logs_dir: str) -> list:
job_names = set()
for f in logs_path.rglob("*.txt"):
name = f.stem.lower()
for job in ["build-image", "sanity-test", "security-test", "telemetry-test", "upstream-tests", "sagemaker-test"]:
for job in [
"build-image",
"sanity-test",
"security-test",
"telemetry-test",
"upstream-tests",
"sagemaker-test",
]:
if job in name:
job_names.add(job)
return list(job_names)
Expand All @@ -120,15 +221,20 @@ def load_context_files(framework: str, failed_jobs: list) -> dict:
"""
mapping_path = Path(CONTEXT_MAP_PATH)
if not mapping_path.exists():
return {p: read_file(p) for p in [
f"docker/{framework}/Dockerfile",
f".github/config/image/{framework}-ec2.yml",
f"test/security/data/ecr_scan_allowlist/{framework}/framework_allowlist.json",
] if read_file(p)}
return {
p: read_file(p)
for p in [
f"docker/{framework}/Dockerfile",
f".github/config/image/{framework}-ec2.yml",
f"test/security/data/ecr_scan_allowlist/{framework}/framework_allowlist.json",
]
if read_file(p)
}

# Parse YAML via subprocess (yq available on runners) or fallback to simple parsing
try:
import yaml

config = yaml.safe_load(mapping_path.read_text())
except ImportError:
# Fallback: parse the simple YAML structure manually
Expand Down Expand Up @@ -166,7 +272,12 @@ def _parse_simple_yaml(text: str) -> dict:
current_job = None
elif line == "jobs:":
current_section = "jobs"
elif current_section == "jobs" and line.startswith(" ") and not line.startswith(" ") and stripped.endswith(":"):
elif (
current_section == "jobs"
and line.startswith(" ")
and not line.startswith(" ")
and stripped.endswith(":")
):
current_job = stripped.rstrip(":")
result["jobs"][current_job] = []
elif stripped.startswith("- "):
Expand All @@ -182,7 +293,9 @@ def get_previous_fixes() -> str:
try:
r = subprocess.run(
["git", "log", "--oneline", "origin/main..HEAD", "--grep=[agent-fix]"],
capture_output=True, text=True, check=True,
capture_output=True,
text=True,
check=True,
)
return r.stdout.strip() or "None"
except subprocess.CalledProcessError:
Expand All @@ -193,8 +306,9 @@ def parse_blocks(response: str) -> list:
blocks = []
for m in SEARCH_REPLACE_PATTERN.finditer(response):
filepath = m.group(1).strip().strip("`").strip()
# Strip common LLM artifacts: <filepath>, **filepath**, `filepath`
filepath = re.sub(r"^<\w+>|<\/\w+>$", "", filepath).strip()
# Strip all common LLM artifacts: <filepath>path, <path>, **path**, `path`
filepath = re.sub(r"^<[^>]*>", "", filepath).strip() # strips <filepath>, <file>, etc.
filepath = re.sub(r"^<|>$", "", filepath).strip() # strips bare < >
filepath = filepath.strip("*").strip("`").strip()
blocks.append({"path": filepath, "search": m.group(2), "replace": m.group(3)})
return blocks
Expand All @@ -207,14 +321,18 @@ def find_match(content: str, search: str) -> tuple:
return idx, idx + len(search)

# Whitespace-normalized: strip trailing spaces per line
norm = lambda s: "\n".join(line.rstrip() for line in s.splitlines())
def norm(s):
return "\n".join(line.rstrip() for line in s.splitlines())

norm_content, norm_search = norm(content), norm(search)
idx = norm_content.find(norm_search)
if idx != -1:
line_num = norm_content[:idx].count("\n")
lines = content.splitlines(keepends=True)
end_line = line_num + norm_search.count("\n")
return sum(len(lines[i]) for i in range(line_num)), sum(len(lines[i]) for i in range(end_line + 1))
return sum(len(lines[i]) for i in range(line_num)), sum(
len(lines[i]) for i in range(end_line + 1)
)

return None, None

Expand Down Expand Up @@ -256,18 +374,19 @@ def call_bedrock(system: str, user: str) -> str:
client = boto3.client("bedrock-runtime", region_name=REGION)
resp = client.invoke_model(
modelId=MODEL_ID,
body=json.dumps({
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": MAX_TOKENS,
"system": system,
"messages": [{"role": "user", "content": user}],
}),
body=json.dumps(
{
"anthropic_version": "bedrock-2023-05-31",
"max_tokens": MAX_TOKENS,
"system": system,
"messages": [{"role": "user", "content": user}],
}
),
)
return json.loads(resp["body"].read())["content"][0]["text"]


def build_prompt(framework, branch, error_lines, context_files,
previous_fixes, retry_context=""):
def build_prompt(framework, branch, error_lines, context_files, previous_fixes, retry_context=""):
files_section = ""
for path, content in context_files.items():
ext = Path(path).suffix.lstrip(".")
Expand Down Expand Up @@ -295,21 +414,12 @@ def main():
args = parse_args()
print(f"=== Currency Fix Agent: {args.framework} @ {args.branch} ===\n")

error_lines = extract_error_lines(args.logs_dir)
failed_jobs = detect_failed_jobs(args.logs_dir)
error_lines, api_failed_jobs = extract_failure_info(args.run_ids, args.token, args.repo)
# Use API-detected jobs if available, otherwise fall back to log filename detection
failed_jobs = api_failed_jobs
context_files = load_context_files(args.framework, failed_jobs)
previous_fixes = get_previous_fixes()

# Debug: show what logs we have
logs_path = Path(args.logs_dir)
if logs_path.exists():
log_files = list(logs_path.rglob("*.txt"))
print(f"Log files found: {len(log_files)}")
for f in log_files[:10]:
print(f" {f.name} ({f.stat().st_size} bytes)")
else:
print(f"WARNING: logs dir {args.logs_dir} does not exist!")

print(f"Error lines extracted: {len(error_lines.splitlines())} lines")
print(f"Error lines preview: {error_lines[:500]}")
print(f"Failed jobs detected: {failed_jobs or 'none (including all files)'}")
Expand All @@ -320,8 +430,9 @@ def main():
for attempt in range(1, MAX_LLM_RETRIES + 1):
print(f"--- Attempt {attempt}/{MAX_LLM_RETRIES} ---")

prompt = build_prompt(args.framework, args.branch, error_lines,
context_files, previous_fixes, retry_context)
prompt = build_prompt(
args.framework, args.branch, error_lines, context_files, previous_fixes, retry_context
)
print(f"Prompt size: {len(prompt)} chars")
response = call_bedrock(SYSTEM_PROMPT, prompt)
print(f"LLM response ({len(response)} chars):")
Expand All @@ -336,7 +447,8 @@ def main():

blocks = parse_blocks(response)
if blocks:
print(f"Parsed {len(blocks)} block(s): {[b["path"] for b in blocks]}")
paths = [b["path"] for b in blocks]
print(f"Parsed {len(blocks)} block(s): {paths}")
if not blocks:
retry_context = (
f"Could not parse search/replace blocks from response.\n"
Expand Down
Loading