diff --git a/scripts/scrape_comments.py b/scripts/scrape_comments.py index 57fc3a2..50a1110 100644 --- a/scripts/scrape_comments.py +++ b/scripts/scrape_comments.py @@ -12,6 +12,14 @@ 4. Non-allowed users get a whitelist notice. Unsupported benchmarks get a supported-list reply. 5. Queue requests reply with a markdown table of pending jobs. +Environment variables can be passed on lines following the trigger: + run benchmark tpch_mem + DATAFUSION_RUNTIME_MEMORY_LIMIT=1G + OTHER_VAR=value + +Env var names must be uppercase/underscore (A-Z_), values alphanumeric with ._- allowed. +They are exported in the generated job script before the benchmark commands. + Repo-specific behavior: - apache/datafusion: - Standard benchmarks (bench.sh): ALLOWED_BENCHMARKS below; command: gh_compare_branch.sh @@ -30,7 +38,7 @@ import re import subprocess import sys -from dataclasses import dataclass +from dataclasses import dataclass, field from datetime import datetime, timedelta, timezone from json import loads from shutil import which @@ -216,6 +224,7 @@ def parse_job_metadata(path: str) -> tuple[str, str, str]: user = "unknown" comment = "" benchmarks: list[str] = [] + env_vars: list[str] = [] try: with open(path, "r") as f: for line in f: @@ -232,10 +241,14 @@ def parse_job_metadata(path: str) -> tuple[str, str, str]: m = re.search(r'BENCH_NAME="([^"]+)"', line) if m: benchmarks.append(m.group(1)) + elif line.startswith("export ") and "=" in line: + env_vars.append(line.removeprefix("export ").strip()) except FileNotFoundError: return user, "", comment benches = " ".join(benchmarks) if benchmarks else "default" + if env_vars: + benches += f" (env: {' '.join(env_vars)})" return user, benches, comment @@ -251,16 +264,48 @@ def list_job_files() -> list[str]: return sorted(files, key=lambda p: os.path.getmtime(p)) -# Returns list of benchmarks to run, or an empty list for the default "run benchmarks". +@dataclass +class BenchmarkRequest: + """Parsed benchmark trigger from a GitHub comment.""" + benchmarks: List[str] # empty = default suite + env_vars: List[str] = field(default_factory=list) # KEY=VALUE strings + + +# Pattern for environment variable lines: KEY=VALUE where KEY is uppercase/underscore +# and VALUE is alphanumeric with common chars (no shell metacharacters). +_ENV_VAR_RE = re.compile(r"^[A-Z_][A-Z0-9_]*=[a-zA-Z0-9._\-]+$") + + +def parse_env_vars(lines: List[str]) -> List[str]: + """Extract valid KEY=VALUE environment variables from body lines.""" + env_vars: List[str] = [] + for line in lines: + stripped = line.strip() + if not stripped: + continue + if _ENV_VAR_RE.match(stripped): + env_vars.append(stripped) + return env_vars + + +# Returns a BenchmarkRequest with benchmarks to run (empty list for the default +# "run benchmarks") and any env vars from subsequent lines. # Returns None if no trigger detected, or if any requested benchmark is unsupported. -def detect_benchmark(cfg: RepoConfig, body: str) -> List[str] | None: +def detect_benchmark(cfg: RepoConfig, body: str) -> BenchmarkRequest | None: + lines = body.strip().splitlines() + if not lines: + return None + + trigger = lines[0] + extra_lines = lines[1:] + # check for "run benchmarks" (default set) - match = re.match(r"^\s*run\s+benchmarks\s*$", body, flags=re.IGNORECASE) + match = re.match(r"^\s*run\s+benchmarks\s*$", trigger, flags=re.IGNORECASE) if match: - return [] + return BenchmarkRequest(benchmarks=[], env_vars=parse_env_vars(extra_lines)) # check for "run benchmark " - match = re.match(r"^\s*run\s+benchmark\s+([a-zA-Z0-9_\-\s]+?)\s*$", body, flags=re.IGNORECASE) + match = re.match(r"^\s*run\s+benchmark\s+([a-zA-Z0-9_\-\s]+?)\s*$", trigger, flags=re.IGNORECASE) if not match: return None @@ -269,7 +314,7 @@ def detect_benchmark(cfg: RepoConfig, body: str) -> List[str] | None: return None if all(name in cfg.allowed_standard or name in cfg.allowed_criterion for name in names): - return names + return BenchmarkRequest(benchmarks=names, env_vars=parse_env_vars(extra_lines)) return None @@ -290,9 +335,16 @@ def pr_number_from_url(url: str) -> str: # BENCHMARKS="" ./gh_compare_branch.sh https://github.com/apache/datafusion/pull/ # - If in ALLOWED_CRITERION_BENCHMARKS: # BENCH_NAME="" ./gh_compare_branch_bench.sh https://github.com/apache/datafusion/pull/ -def get_benchmark_script(cfg: RepoConfig, pr_number: str, benches: List[str]) -> str: +def get_benchmark_script(cfg: RepoConfig, pr_number: str, benches: List[str], env_vars: List[str] | None = None) -> str: pr_url = f"https://github.com/{cfg.repo}/pull/{pr_number}" commands: list[str] = [] + + # Export extra env vars (e.g. DATAFUSION_RUNTIME_MEMORY_LIMIT=1G) + if env_vars: + for ev in env_vars: + commands.append(f'export {ev}') + commands.append('') # blank line for readability + if benches: for bench in benches: if bench in cfg.allowed_criterion: @@ -408,7 +460,12 @@ def post_supported_benchmarks( f"- Standard: {supported_standard}\n" f"- Criterion: {supported_criterion}\n\n" "Please choose one or more of these with `run benchmark ` or " - "`run benchmark ...`" + "`run benchmark ...`\n\n" + "You can also set environment variables on subsequent lines:\n" + "```\n" + "run benchmark tpch_mem\n" + "DATAFUSION_RUNTIME_MEMORY_LIMIT=1G\n" + "```" f"{unsupported}" ) if already_posted(cfg, pr_number, comment_url): @@ -490,12 +547,12 @@ def process_comment(cfg: RepoConfig, comment: Mapping, now: datetime) -> None: post_queue(cfg, pr_number, login, comment_url) return - benches = detect_benchmark(cfg, body) - if benches is None: + request = detect_benchmark(cfg, body) + if request is None: print(f" No benchmark trigger detected in {body}") - if body.strip().lower().startswith("run benchmark"): + if body.strip().splitlines()[0].strip().lower().startswith("run benchmark"): print(" Comment starts with 'run benchmark' but benchmark is unsupported.") - requested = [n for n in body.split()[2:] if n] + requested = [n for n in body.strip().splitlines()[0].split()[2:] if n] if login not in ALLOWED_USERS: post_user_notice(cfg, pr_number, login, comment_url) else: @@ -507,10 +564,12 @@ def process_comment(cfg: RepoConfig, comment: Mapping, now: datetime) -> None: post_user_notice(cfg, pr_number, login, comment_url) return print(f" Found comment from allowed user: {login}") - if benches: - print(f" Benchmarks requested: {' '.join(benches)}") + if request.benchmarks: + print(f" Benchmarks requested: {' '.join(request.benchmarks)}") else: print(" Benchmarks requested: default suite") + if request.env_vars: + print(f" Environment variables: {' '.join(request.env_vars)}") file_name = job_file_name(cfg, pr_number, str(comment_id)) if os.path.exists(file_name): @@ -521,7 +580,7 @@ def process_comment(cfg: RepoConfig, comment: Mapping, now: datetime) -> None: print(f" Job done file {done_file_name} already exists, skipping") return - script_content = get_benchmark_script(cfg, pr_number, benches) + script_content = get_benchmark_script(cfg, pr_number, request.benchmarks, request.env_vars) os.makedirs("jobs", exist_ok=True) pr_url = f"https://github.com/{cfg.repo}/pull/{pr_number}" with open(file_name, "w") as f: