Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ci/ray_ci/automation/ray_wheels_lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List
from typing import List, Optional

import boto3

Expand Down Expand Up @@ -82,6 +82,7 @@ def download_ray_wheels_from_s3(
commit_hash: str,
ray_version: str,
directory_path: str,
branch: Optional[str] = None,
) -> None:
"""
Download Ray wheels from S3 to the given directory.
Expand All @@ -93,8 +94,10 @@ def download_ray_wheels_from_s3(
"""
full_directory_path = os.path.join(bazel_workspace_dir, directory_path)
wheels = _get_wheel_names(ray_version=ray_version)
if not branch:
branch = f"releases/{ray_version}"
for wheel in wheels:
s3_key = f"releases/{ray_version}/{commit_hash}/{wheel}.whl"
s3_key = f"{branch}/{commit_hash}/{wheel}.whl"
download_wheel_from_s3(s3_key, full_directory_path)

_check_downloaded_wheels(full_directory_path, wheels)
Expand Down
30 changes: 30 additions & 0 deletions ci/ray_ci/automation/test_ray_wheels_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,36 @@ def test_download_ray_wheels_from_s3(
mock_check_wheels.assert_called_with(tmp_dir, SAMPLE_WHEELS)


@mock.patch("ci.ray_ci.automation.ray_wheels_lib.download_wheel_from_s3")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._check_downloaded_wheels")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._get_wheel_names")
def test_download_ray_wheels_from_s3_with_branch(
mock_get_wheel_names, mock_check_wheels, mock_download_wheel
):
commit_hash = "1234567"
ray_version = "1.0.0"

mock_get_wheel_names.return_value = SAMPLE_WHEELS

with tempfile.TemporaryDirectory() as tmp_dir:
download_ray_wheels_from_s3(
commit_hash=commit_hash,
ray_version=ray_version,
directory_path=tmp_dir,
branch="custom_branch",
)

mock_get_wheel_names.assert_called_with(ray_version=ray_version)
assert mock_download_wheel.call_count == len(SAMPLE_WHEELS)
for i, call_args in enumerate(mock_download_wheel.call_args_list):
assert (
call_args[0][0] == f"custom_branch/{commit_hash}/{SAMPLE_WHEELS[i]}.whl"
)
assert call_args[0][1] == tmp_dir

mock_check_wheels.assert_called_with(tmp_dir, SAMPLE_WHEELS)
Comment on lines +159 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test function is very similar to test_download_ray_wheels_from_s3. This duplication can make the tests harder to maintain. To improve maintainability, consider refactoring this test and test_download_ray_wheels_from_s3 into a single, parameterized test using pytest.mark.parametrize. This would allow you to test both cases (with and without a branch) with a single test function, reducing code duplication.



@mock.patch("ci.ray_ci.automation.ray_wheels_lib.download_wheel_from_s3")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._check_downloaded_wheels")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._get_wheel_names")
Expand Down
8 changes: 7 additions & 1 deletion ci/ray_ci/automation/upload_wheels_pypi.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,21 @@
@click.option("--ray_version", required=True, type=str)
@click.option("--commit_hash", required=True, type=str)
@click.option("--pypi_env", required=True, type=click.Choice(["test", "prod"]))
@click.option("--branch", required=False, type=str)
@click.option("--build_tag", required=False, type=str)
def main(
ray_version: str, commit_hash: str, pypi_env: str, build_tag: Optional[str] = None
ray_version: str,
commit_hash: str,
pypi_env: str,
branch: Optional[str] = None,
build_tag: Optional[str] = None,
):
with tempfile.TemporaryDirectory() as temp_dir:
download_ray_wheels_from_s3(
commit_hash=commit_hash,
ray_version=ray_version,
directory_path=temp_dir,
branch=branch,
)
if build_tag:
add_build_tag_to_wheels(directory_path=temp_dir, build_tag=build_tag)
Expand Down