Skip to content
220 changes: 58 additions & 162 deletions build_tools/fetch_artifacts.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,110 +25,55 @@
AWS_SESSION_TOKEN
if and only if all are specified in the environment to connect with S3.
If unspecified, we will create an anonymous boto client that can only access public artifacts.

TODO: Evaluate switching to artifact_manager.py which provides a unified backend
abstraction (local directory or S3) and integrates with BUILD_TOPOLOGY.toml for
stage-aware artifact filtering.
"""

import argparse
import boto3
from botocore import UNSIGNED
from botocore.config import Config
import concurrent.futures
from dataclasses import dataclass, field
import os
from pathlib import Path
import platform
import re
import shutil
import sys
import tarfile
import time
from urllib3.exceptions import InsecureRequestWarning
import warnings

from _therock_utils.artifact_backend import ArtifactBackend, S3Backend
from _therock_utils.artifacts import (
ArtifactName,
ArtifactPopulator,
_open_archive_for_read,
)
from artifact_manager import DownloadRequest, download_artifact
from github_actions.github_actions_utils import retrieve_bucket_info


# Both clients below are created with verify=False; silence urllib3's
# InsecureRequestWarning (presumably emitted because of that setting).
warnings.filterwarnings("ignore", category=InsecureRequestWarning)

# AWS credentials are read once, at import time.
_access_key_id = os.environ.get("AWS_ACCESS_KEY_ID")
_secret_access_key = os.environ.get("AWS_SECRET_ACCESS_KEY")
_session_token = os.environ.get("AWS_SESSION_TOKEN")

# Module-level S3 client: credentials are used only when all three variables
# are present; otherwise requests are sent unsigned (anonymous).
if (
    _access_key_id is None
    or _secret_access_key is None
    or _session_token is None
):
    # At least one credential is missing: fall back to an anonymous client.
    s3_client = boto3.client(
        "s3",
        verify=False,
        config=Config(max_pool_connections=100, signature_version=UNSIGNED),
    )
else:
    # Full credential set available: create a signed client.
    s3_client = boto3.client(
        "s3",
        verify=False,
        aws_access_key_id=_access_key_id,
        aws_secret_access_key=_secret_access_key,
        aws_session_token=_session_token,
    )

# Shared paginator used to enumerate objects under a key prefix.
paginator = s3_client.get_paginator("list_objects_v2")


# TODO(geomin12): switch out logging library
def log(*args, **kwargs):
    """Print a message and immediately flush stdout.

    All positional and keyword arguments are forwarded unchanged to the
    built-in ``print``. ``sys.stdout`` is flushed after every call.
    """
    print(*args, **kwargs)
    sys.stdout.flush()


# TODO: move into github_actions_utils.py?
@dataclass
class BucketMetadata:
    """Metadata for a workflow run's artifacts in an AWS S3 bucket."""

    external_repo: str
    bucket: str
    workflow_run_id: str
    platform: str
    # Derived in __post_init__; not accepted as a constructor argument.
    s3_key_path: str = field(init=False)

    def __post_init__(self):
        # external_repo is concatenated with no separator, so any trailing
        # "/" must already be part of it.
        self.s3_key_path = f"{self.external_repo}{self.workflow_run_id}-{self.platform}"


def list_artifacts_for_group(
    backend: "ArtifactBackend", artifact_group: str
) -> set[str]:
    """Lists artifacts from backend, filtered by artifact_group.

    Args:
        backend: ArtifactBackend instance configured for the target run
        artifact_group: GPU family to filter by (e.g., "gfx94X-all"). Also includes
            artifacts with "generic" in the name.

    Returns:
        Set of artifact filenames matching the artifact_group or "generic".
    """
    log(f"Retrieving artifacts from '{backend.base_uri}'")

    # Get all artifacts from backend
    all_artifacts = backend.list_artifacts()

    # Filter by artifact_group (matches if artifact_group or "generic" in filename)
    data = {
        filename
        for filename in all_artifacts
        if artifact_group in filename or "generic" in filename
    }

    if not data:
        log(f"Found no artifacts matching '{artifact_group}' at '{backend.base_uri}'")
    return data


def list_s3_artifacts(bucket_info: BucketMetadata, artifact_group: str) -> set[str]:
    """Lists artifact archive file names stored under a run's S3 key prefix.

    Args:
        bucket_info: Bucket and key prefix of the workflow run to inspect.
        artifact_group: Substring to filter keys by. Keys containing "generic"
            are included as well.

    Returns:
        Set of file names (the last "/"-separated component of each key) for
        matching archives. Empty if nothing matched.
    """
    s3_key_path = bucket_info.s3_key_path
    log(
        f"Retrieving S3 artifacts for {bucket_info.workflow_run_id} in '{bucket_info.bucket}' at '{s3_key_path}'"
    )

    page_iterator = paginator.paginate(Bucket=bucket_info.bucket, Prefix=s3_key_path)
    data = set()
    for page in page_iterator:
        # Pages without a "Contents" entry carry no objects.
        if "Contents" not in page:
            continue

        for artifact in page["Contents"]:
            artifact_key = artifact["Key"]
            # Match both .tar.zst (new) and .tar.xz (legacy) formats
            is_artifact_archive = "tar.zst" in artifact_key or "tar.xz" in artifact_key
            if (
                "sha256sum" not in artifact_key
                and is_artifact_archive
                and (artifact_group in artifact_key or "generic" in artifact_key)
            ):
                file_name = artifact_key.split("/")[-1]
                data.add(file_name)

    if not data:
        # BucketMetadata names this field workflow_run_id (there is no run_id).
        log(
            f"Found no S3 artifacts for {bucket_info.workflow_run_id} at '{s3_key_path}'"
        )
    return data


Expand Down Expand Up @@ -163,75 +108,6 @@ def _should_include(artifact_name: str) -> bool:
return {a for a in artifacts if _should_include(a)}


@dataclass
class ArtifactDownloadRequest:
    """A single artifact to fetch from a bucket and where to store it locally."""

    # Object key of the artifact within the bucket.
    artifact_key: str
    # Name of the bucket holding the artifact.
    bucket: str
    # Local file path the artifact is written to.
    output_path: Path

    def __str__(self):
        # Rendered as "<bucket>:<artifact_key>".
        return "{}:{}".format(self.bucket, self.artifact_key)


def download_artifact(
    artifact_download_request: ArtifactDownloadRequest,
) -> ArtifactDownloadRequest | None:
    """Downloads one artifact to its output path, retrying on failure.

    Up to MAX_RETRIES attempts are made, sleeping BASE_DELAY * 2**attempt
    seconds between attempts.

    Args:
        artifact_download_request: The artifact key, bucket and local output
            path to download to.

    Returns:
        The request that was passed in once the download succeeds, or None if
        every attempt failed.
    """
    MAX_RETRIES = 3
    BASE_DELAY = 3  # seconds

    # Bind these before the retry loop so the except branch can always use them.
    artifact_key = artifact_download_request.artifact_key
    bucket = artifact_download_request.bucket
    output_path = artifact_download_request.output_path

    for attempt in range(MAX_RETRIES):
        try:
            log(f"++ Downloading {artifact_key} to {output_path}")
            # Mode "wb" truncates the file, so each retry starts from scratch.
            with open(output_path, "wb") as f:
                s3_client.download_fileobj(bucket, artifact_key, f)
        except Exception as e:
            # Deliberately broad: any failure is logged and retried, not raised.
            log(f"++ Error downloading {artifact_key}: {e}")
            if attempt < MAX_RETRIES - 1:
                delay = BASE_DELAY * (2**attempt)
                log(f"Retrying in {delay} seconds...")
                time.sleep(delay)
        else:
            log(f"++ Download complete for {output_path}")
            return artifact_download_request

    log(f"++ Failed downloading from {artifact_key} after {MAX_RETRIES} retries")
    return None


def download_artifacts(artifact_download_requests: list[ArtifactDownloadRequest]):
    """Runs download_artifact for every request on a thread pool.

    Each request is submitted to a default-sized ThreadPoolExecutor, and each
    result is collected as its future completes. An exception raised by a
    worker propagates from here.
    """
    with concurrent.futures.ThreadPoolExecutor() as pool:
        pending = []
        for request in artifact_download_requests:
            pending.append(pool.submit(download_artifact, request))

        for finished in concurrent.futures.as_completed(pending):
            finished.result(timeout=60)


def get_artifact_download_requests(
    bucket_info: BucketMetadata,
    s3_artifacts: set[str],
    output_dir: Path,
) -> list[ArtifactDownloadRequest]:
    """Builds one download request per artifact name, in sorted name order.

    Args:
        bucket_info: Supplies the bucket name and the key prefix for each artifact.
        s3_artifacts: Artifact file names to download.
        output_dir: Directory each artifact is written into, under its own name.

    Returns:
        List of ArtifactDownloadRequest, ordered by artifact name.
    """
    key_prefix = bucket_info.s3_key_path
    bucket_name = bucket_info.bucket
    return [
        ArtifactDownloadRequest(
            artifact_key=f"{key_prefix}/{name}",
            bucket=bucket_name,
            output_path=output_dir / name,
        )
        for name in sorted(s3_artifacts)
    ]


def get_postprocess_mode(args) -> str | None:
"""Returns 'extract', 'flatten' or None (default is 'extract')."""
if args.flatten:
Expand All @@ -242,13 +118,22 @@ def get_postprocess_mode(args) -> str | None:


def extract_artifact(
artifact: ArtifactDownloadRequest, *, delete_archive: bool, postprocess_mode: str
archive_path: Path, *, delete_archive: bool, postprocess_mode: str
Comment thread
ScottTodd marked this conversation as resolved.
):
"""Extracts and postprocesses an artifact from an archive file in-place.

Args:
archive_path: Path to the archive (e.g. `amd-llvm_lib_generic.tar.xz`)
delete_archive: True to delete the archive after extraction
postprocess_mode: Either 'flatten' or 'extract'
* 'flatten' merges artifacts into a single "dist/" directory
* 'extract' puts each artifact in a dir (e.g. `amd-llvm_lib_generic/`)
"""
# Get (for example) 'amd-llvm_lib_generic' from '/path/to/amd-llvm_lib_generic.tar.xz'
# We can't just use .stem since that only removes the last extension.
# 1. .name gets us 'amd-llvm_lib_generic.tar.xz'
# 2. .partition('.') gets (before, sep, after), discard all but 'before'
archive_file = artifact.output_path
archive_file = archive_path
artifact_name, *_ = archive_file.name.partition(".")

if postprocess_mode == "extract":
Expand Down Expand Up @@ -277,56 +162,64 @@ def run(args):
run_id = args.run_id
artifact_group = args.artifact_group
output_dir = args.output_dir
output_dir.mkdir(parents=True, exist_ok=True)

external_repo, bucket = retrieve_bucket_info(
github_repository=run_github_repo,
workflow_run_id=run_id,
)
bucket_info = BucketMetadata(
external_repo=external_repo,
backend = S3Backend(
bucket=bucket,
workflow_run_id=run_id,
run_id=run_id,
platform=args.platform,
external_repo=external_repo,
)

# Lookup which artifacts exist in the bucket.
# Note: this currently does not check that all requested artifacts
# (via include patterns) do exist, so this may silently fail to fetch
# expected files.
s3_artifacts = list_s3_artifacts(
bucket_info=bucket_info, artifact_group=artifact_group
available_artifacts = list_artifacts_for_group(
backend=backend, artifact_group=artifact_group
)
if not s3_artifacts:
if not available_artifacts:
log(f"No matching artifacts for {run_id} exist. Exiting...")
sys.exit(1)

# Include/exclude filtering.
s3_artifacts_filtered = filter_artifacts(s3_artifacts, args.include, args.exclude)
if not s3_artifacts_filtered:
filtered_artifacts = filter_artifacts(
available_artifacts, args.include, args.exclude
)
if not filtered_artifacts:
log(f"Filtering artifacts for {run_id} resulted in an empty set. Exiting...")
sys.exit(1)

artifacts_to_download = get_artifact_download_requests(
bucket_info=bucket_info,
s3_artifacts=s3_artifacts_filtered,
output_dir=output_dir,
)
download_requests = [
DownloadRequest(
artifact_key=artifact,
dest_path=output_dir / artifact,
backend=backend,
)
for artifact in sorted(filtered_artifacts)
]

download_summary = "\n ".join([str(item) for item in artifacts_to_download])
download_summary = "\n ".join(
[f"{req.backend.base_uri}/{req.artifact_key}" for req in download_requests]
)
log(f"\nFiltered artifacts to download:\n {download_summary}\n")

if args.dry_run:
log("Skipping downloads since --dry-run was set")
return

output_dir.mkdir(parents=True, exist_ok=True)

# Download and extract in parallel.
with concurrent.futures.ThreadPoolExecutor(
max_workers=args.download_concurrency
) as download_executor:
download_futures = [
download_executor.submit(download_artifact, req)
for req in artifacts_to_download
for req in download_requests
]

postprocess_mode = get_postprocess_mode(args)
Expand All @@ -340,7 +233,10 @@ def run(args):
) as extract_executor:
extract_futures: list[concurrent.futures.Future] = []
for download_future in concurrent.futures.as_completed(download_futures):
# download_artifact returns Optional[Path] - None on failure
download_result = download_future.result(timeout=60)
if download_result is None:
continue
extract_futures.append(
extract_executor.submit(
extract_artifact,
Expand Down
Loading
Loading