
Commit b7ea62b

ci: Add ability to array-ify args and run multiple jobs (#3584)
# Overview

Previously, the `run-cluster` workflow ran only one Ray job submission. This PR extends it to run an arbitrary array of job submissions: the `entrypoint_args` input param now also accepts a JSON array, which is split into its individual argument strings, each of which is submitted as its own job.

## Example Usage

```sh
gh workflow run run-cluster.yaml \
  --ref $current_branch \
  -f working_dir="." \
  -f daft_wheel_url="https://github-actions-artifacts-bucket.s3.us-west-2.amazonaws.com/builds/54428e3738e96764af60cfdd8a0e4a41717ec9f9/getdaft-0.3.0.dev0-cp38-abi3-manylinux_2_31_x86_64.whl" \
  -f entrypoint_script="benchmarking/tpcds/ray_entrypoint.py" \
  -f entrypoint_args="[\"--tpcds-gen-folder='gendata' --question='1'\", \"--tpcds-gen-folder='gendata' --question='2'\"]"
```

The above invocation runs TPC-DS queries 1 and 2.
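For comparison, the pre-existing single-job behavior still works by passing a plain string instead of a JSON list. A minimal sketch (wheel/version inputs omitted for brevity):

```sh
gh workflow run run-cluster.yaml \
  --ref $current_branch \
  -f working_dir="." \
  -f entrypoint_script="benchmarking/tpcds/ray_entrypoint.py" \
  -f entrypoint_args="--tpcds-gen-folder='gendata' --question='1'"
```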
1 parent 47f5897 commit b7ea62b

4 files changed: +180 −36 lines

.github/ci-scripts/job_runner.py

Lines changed: 114 additions & 0 deletions
```python
# /// script
# requires-python = ">=3.12"
# dependencies = []
# ///

import argparse
import asyncio
import json
from dataclasses import dataclass
from datetime import datetime, timedelta
from pathlib import Path
from typing import Optional

from ray.job_submission import JobStatus, JobSubmissionClient


def parse_env_var_str(env_var_str: str) -> dict:
    iter = map(
        lambda s: s.strip().split("="),
        filter(lambda s: s, env_var_str.split(",")),
    )
    return {k: v for k, v in iter}


async def print_logs(logs):
    async for lines in logs:
        print(lines, end="")


async def wait_on_job(logs, timeout_s):
    await asyncio.wait_for(print_logs(logs), timeout=timeout_s)


@dataclass
class Result:
    query: int
    duration: timedelta
    error_msg: Optional[str]


def submit_job(
    working_dir: Path,
    entrypoint_script: str,
    entrypoint_args: str,
    env_vars: str,
    enable_ray_tracing: bool,
):
    env_vars_dict = parse_env_var_str(env_vars)
    if enable_ray_tracing:
        env_vars_dict["DAFT_ENABLE_RAY_TRACING"] = "1"

    client = JobSubmissionClient(address="http://localhost:8265")

    if entrypoint_args.startswith("[") and entrypoint_args.endswith("]"):
        # this is a json-encoded list of strings; parse accordingly
        list_of_entrypoint_args: list[str] = json.loads(entrypoint_args)
    else:
        list_of_entrypoint_args: list[str] = [entrypoint_args]

    results = []

    for index, args in enumerate(list_of_entrypoint_args):
        entrypoint = f"DAFT_RUNNER=ray python {entrypoint_script} {args}"
        print(f"{entrypoint=}")
        start = datetime.now()
        job_id = client.submit_job(
            entrypoint=entrypoint,
            runtime_env={
                "working_dir": working_dir,
                "env_vars": env_vars_dict,
            },
        )

        asyncio.run(wait_on_job(client.tail_job_logs(job_id), timeout_s=60 * 30))

        status = client.get_job_status(job_id)
        assert status.is_terminal(), "Job should have terminated"
        end = datetime.now()
        duration = end - start
        error_msg = None
        if status != JobStatus.SUCCEEDED:
            job_info = client.get_job_info(job_id)
            error_msg = job_info.message

        result = Result(query=index, duration=duration, error_msg=error_msg)
        results.append(result)

    print(f"{results=}")


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--working-dir", type=Path, required=True)
    parser.add_argument("--entrypoint-script", type=str, required=True)
    parser.add_argument("--entrypoint-args", type=str, required=True)
    parser.add_argument("--env-vars", type=str, required=True)
    parser.add_argument("--enable-ray-tracing", action="store_true")

    args = parser.parse_args()

    if not (args.working_dir.exists() and args.working_dir.is_dir()):
        raise ValueError("The working-dir must exist and be a directory")

    entrypoint: Path = args.working_dir / args.entrypoint_script
    if not (entrypoint.exists() and entrypoint.is_file()):
        raise ValueError("The entrypoint script must exist and be a file")

    submit_job(
        working_dir=args.working_dir,
        entrypoint_script=args.entrypoint_script,
        entrypoint_args=args.entrypoint_args,
        env_vars=args.env_vars,
        enable_ray_tracing=args.enable_ray_tracing,
    )
```
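For context, a minimal sketch of invoking the runner by hand, mirroring the workflow step further below. The argument values are illustrative, and a Ray cluster must already be reachable at `localhost:8265`:

```sh
# --env-vars takes comma-separated KEY=value pairs (see parse_env_var_str);
# the key names here are hypothetical.
python .github/ci-scripts/job_runner.py \
  --working-dir='.' \
  --entrypoint-script='benchmarking/tpcds/ray_entrypoint.py' \
  --entrypoint-args='["--question=1", "--question=2"]' \
  --env-vars='MY_KEY1=value1,MY_KEY2=value2' \
  --enable-ray-tracing
```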

.github/ci-scripts/templatize_ray_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -110,5 +110,7 @@ class Metadata(BaseModel, extra="allow"):
 if metadata:
     metadata = Metadata(**metadata)
     content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies))
+else:
+    content = content.replace(OTHER_INSTALL_PLACEHOLDER, "")
 
 print(content)
```

.github/workflows/run-cluster.yaml

Lines changed: 14 additions & 22 deletions
```diff
@@ -34,7 +34,7 @@ on:
       type: string
       required: true
     entrypoint_args:
-      description: Entry-point arguments
+      description: Entry-point arguments (either a simple string or a JSON list)
       type: string
       required: false
       default: ""
@@ -79,24 +79,15 @@ jobs:
           uv run \
             --python 3.12 \
             .github/ci-scripts/templatize_ray_config.py \
-            --cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
-            --daft-wheel-url '${{ inputs.daft_wheel_url }}' \
-            --daft-version '${{ inputs.daft_version }}' \
-            --python-version '${{ inputs.python_version }}' \
-            --cluster-profile '${{ inputs.cluster_profile }}' \
-            --working-dir '${{ inputs.working_dir }}' \
-            --entrypoint-script '${{ inputs.entrypoint_script }}'
+            --cluster-name="ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \
+            --daft-wheel-url='${{ inputs.daft_wheel_url }}' \
+            --daft-version='${{ inputs.daft_version }}' \
+            --python-version='${{ inputs.python_version }}' \
+            --cluster-profile='${{ inputs.cluster_profile }}' \
+            --working-dir='${{ inputs.working_dir }}' \
+            --entrypoint-script='${{ inputs.entrypoint_script }}'
         ) >> .github/assets/ray.yaml
         cat .github/assets/ray.yaml
-    - name: Setup ray env vars
-      run: |
-        source .venv/bin/activate
-        ray_env_var=$(python .github/ci-scripts/format_env_vars.py \
-          --env-vars '${{ inputs.env_vars }}' \
-          --enable-ray-tracing \
-        )
-        echo $ray_env_var
-        echo "ray_env_var=$ray_env_var" >> $GITHUB_ENV
     - name: Download private ssh key
       run: |
         KEY=$(aws secretsmanager get-secret-value --secret-id ci-github-actions-ray-cluster-key-3 --query SecretString --output text)
@@ -117,11 +108,12 @@
           echo 'Invalid command submitted; command cannot be empty'
           exit 1
         fi
-        ray job submit \
-          --working-dir ${{ inputs.working_dir }} \
-          --address http://localhost:8265 \
-          --runtime-env-json "$ray_env_var" \
-          -- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }}
+        python .github/ci-scripts/job_runner.py \
+          --working-dir='${{ inputs.working_dir }}' \
+          --entrypoint-script='${{ inputs.entrypoint_script }}' \
+          --entrypoint-args='${{ inputs.entrypoint_args }}' \
+          --env-vars='${{ inputs.env_vars }}' \
+          --enable-ray-tracing
     - name: Download log files from ray cluster
       run: |
         source .venv/bin/activate
```

benchmarking/tpcds/ray_entrypoint.py

Lines changed: 50 additions & 14 deletions
```diff
@@ -1,17 +1,54 @@
 import argparse
 from pathlib import Path
 
-import helpers
-
 import daft
+from daft.sql.sql import SQLCatalog
+
+TABLE_NAMES = [
+    "call_center",
+    "catalog_page",
+    "catalog_returns",
+    "catalog_sales",
+    "customer",
+    "customer_address",
+    "customer_demographics",
+    "date_dim",
+    "household_demographics",
+    "income_band",
+    "inventory",
+    "item",
+    "promotion",
+    "reason",
+    "ship_mode",
+    "store",
+    "store_returns",
+    "store_sales",
+    "time_dim",
+    "warehouse",
+    "web_page",
+    "web_returns",
+    "web_sales",
+    "web_site",
+]
+
+
+def register_catalog(scale_factor: int) -> SQLCatalog:
+    return SQLCatalog(
+        tables={
+            table: daft.read_parquet(
+                f"s3://eventual-dev-benchmarking-fixtures/uncompressed/tpcds-dbgen/{scale_factor}/{table}.parquet"
+            )
+            for table in TABLE_NAMES
+        }
+    )
 
 
 def run(
-    parquet_folder: Path,
     question: int,
     dry_run: bool,
+    scale_factor: int,
 ):
-    catalog = helpers.generate_catalog(parquet_folder)
+    catalog = register_catalog(scale_factor)
     query_file = Path(__file__).parent / "queries" / f"{question:02}.sql"
     with open(query_file) as f:
         query = f.read()
@@ -23,27 +60,26 @@ def run(
 
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
-    parser.add_argument(
-        "--tpcds-gen-folder",
-        required=True,
-        type=Path,
-        help="Path to the TPC-DS data generation folder",
-    )
     parser.add_argument(
         "--question",
-        required=True,
         type=int,
         help="The TPC-DS question index to run",
+        required=True,
     )
     parser.add_argument(
         "--dry-run",
         action="store_true",
         help="Whether or not to run the query in dry-run mode; if true, only the plan will be printed out",
     )
+    parser.add_argument(
+        "--scale-factor",
+        type=int,
+        help="Which scale factor to run this data at",
+        required=False,
+        default=2,
+    )
     args = parser.parse_args()
 
-    tpcds_gen_folder: Path = args.tpcds_gen_folder
-    assert tpcds_gen_folder.exists()
     assert args.question in range(1, 100)
 
-    run(args.tpcds_gen_folder, args.question, args.dry_run)
+    run(question=args.question, dry_run=args.dry_run, scale_factor=args.scale_factor)
```
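As a quick sanity check, the entrypoint can also be run outside the runner; a minimal sketch that assumes read access to the S3 fixtures bucket referenced above:

```sh
# --dry-run prints only the query plan; drop it to execute the query.
DAFT_RUNNER=ray python benchmarking/tpcds/ray_entrypoint.py \
  --question=1 \
  --scale-factor=2 \
  --dry-run
```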
