diff --git a/.github/workflows/agent-currency-fix.yml b/.github/workflows/agent-currency-fix.yml index 725b8f151c0c..f58860acbf73 100644 --- a/.github/workflows/agent-currency-fix.yml +++ b/.github/workflows/agent-currency-fix.yml @@ -139,18 +139,6 @@ jobs: Agent run: ${{ github.server_url }}/${{ github.repository }}/actions/runs/${{ github.run_id }}" fi - - name: Download failed run logs - if: steps.failures.outputs.has_failures == 'true' && steps.retry.outputs.max_reached != 'true' - env: - GH_TOKEN: ${{ steps.app-token.outputs.token }} - FAILED_RUNS: ${{ steps.failures.outputs.failed_runs }} - run: | - mkdir -p /tmp/ci-logs - for RUN in $FAILED_RUNS; do - gh api "/repos/${{ github.repository }}/actions/runs/${RUN}/logs" > "/tmp/ci-logs-${RUN}.zip" || true - [ -s "/tmp/ci-logs-${RUN}.zip" ] && unzip -o "/tmp/ci-logs-${RUN}.zip" -d "/tmp/ci-logs/run-${RUN}/" || true - done - - name: Determine framework if: steps.failures.outputs.has_failures == 'true' && steps.retry.outputs.max_reached != 'true' run: | @@ -162,12 +150,15 @@ jobs: id: fix env: AWS_REGION: us-west-2 + GH_TOKEN: ${{ steps.app-token.outputs.token }} run: | python3 -m pip install boto3 -q python3 /tmp/agent-fix.py \ - --logs-dir /tmp/ci-logs/ \ --framework "$FRAMEWORK" \ - --branch "$HEAD_BRANCH" + --branch "$HEAD_BRANCH" \ + --run-ids "${{ steps.failures.outputs.failed_runs }}" \ + --token "$GH_TOKEN" \ + --repo "${{ github.repository }}" - name: Commit and push if: steps.failures.outputs.has_failures == 'true' && steps.retry.outputs.max_reached != 'true' && steps.fix.outcome == 'success' diff --git a/scripts/autocurrency/agent-fix.py b/scripts/autocurrency/agent-fix.py index 32e4848a7186..88dd6ef5155e 100755 --- a/scripts/autocurrency/agent-fix.py +++ b/scripts/autocurrency/agent-fix.py @@ -34,6 +34,7 @@ - ONLY fix the specific failure shown in the logs - Do NOT delete or skip tests - Do NOT modify files unrelated to the failure +- ONLY edit files that are provided in the context below. If a file is not shown, do not edit it. - For CVE scan failures: pin a safe version in Dockerfile, or add to allowlist if vendored/unpatchable - For "file not found" errors: find the new path in the upstream repo - For build errors: check if upstream base image changed something @@ -43,30 +44,123 @@ If the failure is TRANSIENT (capacity, timeout, runner crash), respond with exactly: TRANSIENT: -Otherwise, respond with search/replace blocks: +Otherwise, respond with search/replace blocks. Use this EXACT format: - +path/to/file.ext <<<<<<< SEARCH - +exact text to find in the file ======= - +replacement text >>>>>>> REPLACE +IMPORTANT: Write the file path as plain text (e.g., docker/vllm/Dockerfile). Do NOT wrap it in angle brackets, backticks, or any other formatting. + Include 1-2 surrounding lines in SEARCH for unique anchoring. For JSON arrays (allowlists), SEARCH the last few lines and REPLACE with those lines plus the new entry. -End with: DESCRIPTION: """ +End with: DESCRIPTION: one-line commit message""" def parse_args(): p = argparse.ArgumentParser() - p.add_argument("--logs-dir", required=True) p.add_argument("--framework", required=True) p.add_argument("--branch", required=True) + p.add_argument("--run-ids", default="", help="Space-separated failed run IDs") + p.add_argument("--token", default=os.environ.get("GH_TOKEN", ""), help="GitHub token") + p.add_argument("--repo", default="aws/deep-learning-containers") return p.parse_args() -def extract_error_lines(logs_dir: str) -> str: +def extract_failure_info(run_ids: str, token: str, repo: str) -> tuple: + """Use GitHub API to get structured failure info. Returns (error_text, failed_job_names).""" + print("Using GitHub API for structured failure extraction") + import urllib.request + + results = [] + failed_job_names = [] + for run_id in run_ids.strip().split(): + if not run_id: + continue + # Get jobs for this run + url = f"https://api.github.com/repos/{repo}/actions/runs/{run_id}/jobs?per_page=100" + req = urllib.request.Request( + url, + headers={ + "Authorization": f"token {token}", + "Accept": "application/vnd.github+json", + }, + ) + try: + resp = urllib.request.urlopen(req) + data = json.loads(resp.read()) + except Exception as e: + results.append(f"Failed to fetch jobs for run {run_id}: {e}") + continue + + # Find failed jobs and steps + tracked_jobs = [ + "build-image", + "sanity-test", + "security-test", + "telemetry-test", + "upstream-tests", + "sagemaker-test", + ] + for job in data.get("jobs", []): + if job.get("conclusion") != "failure": + continue + + # Only process jobs that match our tracked job names + job_lower = job["name"].lower() + matched_key = None + for key in tracked_jobs: + if key.replace("-", "") in job_lower.replace("-", "").replace(" ", ""): + matched_key = key + break + if not matched_key: + continue + + failed_steps = [ + s["name"] for s in job.get("steps", []) if s.get("conclusion") == "failure" + ] + results.append(f"FAILED JOB: {job['name']}") + failed_job_names.append(matched_key) + results.append(f" Failed steps: {', '.join(failed_steps)}") + + # Download log from run zip + import io + import zipfile + + zip_url = f"https://api.github.com/repos/{repo}/actions/runs/{run_id}/logs" + zip_req = urllib.request.Request( + zip_url, + headers={ + "Authorization": f"token {token}", + "Accept": "application/vnd.github+json", + }, + ) + try: + resp = urllib.request.urlopen(zip_req) + z = zipfile.ZipFile(io.BytesIO(resp.read())) + target = job["name"].replace(" / ", " _ ") + for name in z.namelist(): + if target in name: + log_lines = z.read(name).decode(errors="replace").splitlines() + results.append(f" Log ({name}, {len(log_lines)} lines):") + results.extend(f" {line}" for line in log_lines) + break + else: + results.append(f" No matching log file for '{target}' in zip") + except Exception as e: + results.append(f" Failed to download logs: {e}") + + results.append("") + + return "\n".join(results) or "No failure info extracted.", failed_job_names + + +def _extract_via_grep(logs_dir: str) -> str: + """Fallback: grep log files for error keywords.""" logs_path = Path(logs_dir) if not logs_path.exists(): return "No logs available." @@ -82,7 +176,7 @@ def extract_error_lines(logs_dir: str) -> str: for i, line in enumerate(lines): if any(kw in line.lower() for kw in keywords): start, end = max(0, i - 2), min(len(lines), i + 3) - error_lines.append(f"--- {log_file.name}:{i+1} ---") + error_lines.append(f"--- {log_file.name}:{i + 1} ---") error_lines.extend(lines[start:end]) error_lines.append("") if len(error_lines) > MAX_LOG_LINES: @@ -107,7 +201,14 @@ def detect_failed_jobs(logs_dir: str) -> list: job_names = set() for f in logs_path.rglob("*.txt"): name = f.stem.lower() - for job in ["build-image", "sanity-test", "security-test", "telemetry-test", "upstream-tests", "sagemaker-test"]: + for job in [ + "build-image", + "sanity-test", + "security-test", + "telemetry-test", + "upstream-tests", + "sagemaker-test", + ]: if job in name: job_names.add(job) return list(job_names) @@ -120,15 +221,20 @@ def load_context_files(framework: str, failed_jobs: list) -> dict: """ mapping_path = Path(CONTEXT_MAP_PATH) if not mapping_path.exists(): - return {p: read_file(p) for p in [ - f"docker/{framework}/Dockerfile", - f".github/config/image/{framework}-ec2.yml", - f"test/security/data/ecr_scan_allowlist/{framework}/framework_allowlist.json", - ] if read_file(p)} + return { + p: read_file(p) + for p in [ + f"docker/{framework}/Dockerfile", + f".github/config/image/{framework}-ec2.yml", + f"test/security/data/ecr_scan_allowlist/{framework}/framework_allowlist.json", + ] + if read_file(p) + } # Parse YAML via subprocess (yq available on runners) or fallback to simple parsing try: import yaml + config = yaml.safe_load(mapping_path.read_text()) except ImportError: # Fallback: parse the simple YAML structure manually @@ -166,7 +272,12 @@ def _parse_simple_yaml(text: str) -> dict: current_job = None elif line == "jobs:": current_section = "jobs" - elif current_section == "jobs" and line.startswith(" ") and not line.startswith(" ") and stripped.endswith(":"): + elif ( + current_section == "jobs" + and line.startswith(" ") + and not line.startswith(" ") + and stripped.endswith(":") + ): current_job = stripped.rstrip(":") result["jobs"][current_job] = [] elif stripped.startswith("- "): @@ -182,7 +293,9 @@ def get_previous_fixes() -> str: try: r = subprocess.run( ["git", "log", "--oneline", "origin/main..HEAD", "--grep=[agent-fix]"], - capture_output=True, text=True, check=True, + capture_output=True, + text=True, + check=True, ) return r.stdout.strip() or "None" except subprocess.CalledProcessError: @@ -193,8 +306,9 @@ def parse_blocks(response: str) -> list: blocks = [] for m in SEARCH_REPLACE_PATTERN.finditer(response): filepath = m.group(1).strip().strip("`").strip() - # Strip common LLM artifacts: , **filepath**, `filepath` - filepath = re.sub(r"^<\w+>|<\/\w+>$", "", filepath).strip() + # Strip all common LLM artifacts: path, , **path**, `path` + filepath = re.sub(r"^<[^>]*>", "", filepath).strip() # strips , , etc. + filepath = re.sub(r"^<|>$", "", filepath).strip() # strips bare < > filepath = filepath.strip("*").strip("`").strip() blocks.append({"path": filepath, "search": m.group(2), "replace": m.group(3)}) return blocks @@ -207,14 +321,18 @@ def find_match(content: str, search: str) -> tuple: return idx, idx + len(search) # Whitespace-normalized: strip trailing spaces per line - norm = lambda s: "\n".join(line.rstrip() for line in s.splitlines()) + def norm(s): + return "\n".join(line.rstrip() for line in s.splitlines()) + norm_content, norm_search = norm(content), norm(search) idx = norm_content.find(norm_search) if idx != -1: line_num = norm_content[:idx].count("\n") lines = content.splitlines(keepends=True) end_line = line_num + norm_search.count("\n") - return sum(len(lines[i]) for i in range(line_num)), sum(len(lines[i]) for i in range(end_line + 1)) + return sum(len(lines[i]) for i in range(line_num)), sum( + len(lines[i]) for i in range(end_line + 1) + ) return None, None @@ -256,18 +374,19 @@ def call_bedrock(system: str, user: str) -> str: client = boto3.client("bedrock-runtime", region_name=REGION) resp = client.invoke_model( modelId=MODEL_ID, - body=json.dumps({ - "anthropic_version": "bedrock-2023-05-31", - "max_tokens": MAX_TOKENS, - "system": system, - "messages": [{"role": "user", "content": user}], - }), + body=json.dumps( + { + "anthropic_version": "bedrock-2023-05-31", + "max_tokens": MAX_TOKENS, + "system": system, + "messages": [{"role": "user", "content": user}], + } + ), ) return json.loads(resp["body"].read())["content"][0]["text"] -def build_prompt(framework, branch, error_lines, context_files, - previous_fixes, retry_context=""): +def build_prompt(framework, branch, error_lines, context_files, previous_fixes, retry_context=""): files_section = "" for path, content in context_files.items(): ext = Path(path).suffix.lstrip(".") @@ -295,21 +414,12 @@ def main(): args = parse_args() print(f"=== Currency Fix Agent: {args.framework} @ {args.branch} ===\n") - error_lines = extract_error_lines(args.logs_dir) - failed_jobs = detect_failed_jobs(args.logs_dir) + error_lines, api_failed_jobs = extract_failure_info(args.run_ids, args.token, args.repo) + # Use API-detected jobs if available, otherwise fall back to log filename detection + failed_jobs = api_failed_jobs context_files = load_context_files(args.framework, failed_jobs) previous_fixes = get_previous_fixes() - # Debug: show what logs we have - logs_path = Path(args.logs_dir) - if logs_path.exists(): - log_files = list(logs_path.rglob("*.txt")) - print(f"Log files found: {len(log_files)}") - for f in log_files[:10]: - print(f" {f.name} ({f.stat().st_size} bytes)") - else: - print(f"WARNING: logs dir {args.logs_dir} does not exist!") - print(f"Error lines extracted: {len(error_lines.splitlines())} lines") print(f"Error lines preview: {error_lines[:500]}") print(f"Failed jobs detected: {failed_jobs or 'none (including all files)'}") @@ -320,8 +430,9 @@ def main(): for attempt in range(1, MAX_LLM_RETRIES + 1): print(f"--- Attempt {attempt}/{MAX_LLM_RETRIES} ---") - prompt = build_prompt(args.framework, args.branch, error_lines, - context_files, previous_fixes, retry_context) + prompt = build_prompt( + args.framework, args.branch, error_lines, context_files, previous_fixes, retry_context + ) print(f"Prompt size: {len(prompt)} chars") response = call_bedrock(SYSTEM_PROMPT, prompt) print(f"LLM response ({len(response)} chars):") @@ -336,7 +447,8 @@ def main(): blocks = parse_blocks(response) if blocks: - print(f"Parsed {len(blocks)} block(s): {[b["path"] for b in blocks]}") + paths = [b["path"] for b in blocks] + print(f"Parsed {len(blocks)} block(s): {paths}") if not blocks: retry_context = ( f"Could not parse search/replace blocks from response.\n"