Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions ci/ray_ci/automation/ray_wheels_lib.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import os
from typing import List
from typing import List, Optional

import boto3

Expand Down Expand Up @@ -82,6 +82,7 @@ def download_ray_wheels_from_s3(
commit_hash: str,
ray_version: str,
directory_path: str,
branch: Optional[str] = None,
) -> None:
"""
Download Ray wheels from S3 to the given directory.
Expand All @@ -93,8 +94,10 @@ def download_ray_wheels_from_s3(
"""
full_directory_path = os.path.join(bazel_workspace_dir, directory_path)
wheels = _get_wheel_names(ray_version=ray_version)
if not branch:
branch = f"releases/{ray_version}"
for wheel in wheels:
s3_key = f"releases/{ray_version}/{commit_hash}/{wheel}.whl"
s3_key = f"{branch}/{commit_hash}/{wheel}.whl"
download_wheel_from_s3(s3_key, full_directory_path)

_check_downloaded_wheels(full_directory_path, wheels)
Expand Down
30 changes: 30 additions & 0 deletions ci/ray_ci/automation/test_ray_wheels_lib.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,36 @@ def test_download_ray_wheels_from_s3(
mock_check_wheels.assert_called_with(tmp_dir, SAMPLE_WHEELS)


@mock.patch("ci.ray_ci.automation.ray_wheels_lib.download_wheel_from_s3")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._check_downloaded_wheels")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._get_wheel_names")
def test_download_ray_wheels_from_s3_with_branch(
mock_get_wheel_names, mock_check_wheels, mock_download_wheel
):
commit_hash = "1234567"
ray_version = "1.0.0"

mock_get_wheel_names.return_value = SAMPLE_WHEELS

with tempfile.TemporaryDirectory() as tmp_dir:
download_ray_wheels_from_s3(
commit_hash=commit_hash,
ray_version=ray_version,
directory_path=tmp_dir,
branch="custom_branch",
)

mock_get_wheel_names.assert_called_with(ray_version=ray_version)
assert mock_download_wheel.call_count == len(SAMPLE_WHEELS)
for i, call_args in enumerate(mock_download_wheel.call_args_list):
assert (
call_args[0][0] == f"custom_branch/{commit_hash}/{SAMPLE_WHEELS[i]}.whl"
)
assert call_args[0][1] == tmp_dir

mock_check_wheels.assert_called_with(tmp_dir, SAMPLE_WHEELS)
Comment on lines +159 to +183
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This test function is very similar to test_download_ray_wheels_from_s3. This duplication can make the tests harder to maintain. To improve maintainability, consider refactoring this test and test_download_ray_wheels_from_s3 into a single, parameterized test using pytest.mark.parametrize. This would allow you to test both cases (with and without a branch) with a single test function, reducing code duplication.



@mock.patch("ci.ray_ci.automation.ray_wheels_lib.download_wheel_from_s3")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._check_downloaded_wheels")
@mock.patch("ci.ray_ci.automation.ray_wheels_lib._get_wheel_names")
Expand Down
8 changes: 7 additions & 1 deletion ci/ray_ci/automation/upload_wheels_pypi.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,14 +15,20 @@
@click.option("--commit_hash", required=True, type=str)
@click.option("--pypi_env", required=True, type=click.Choice(["test", "prod"]))
@click.option("--build_tag", required=False, type=str)
@click.option("--branch", required=False, type=str)
def main(
ray_version: str, commit_hash: str, pypi_env: str, build_tag: Optional[str] = None
ray_version: str,
commit_hash: str,
pypi_env: str,
branch: Optional[str] = None,
build_tag: Optional[str] = None,
):
with tempfile.TemporaryDirectory() as temp_dir:
download_ray_wheels_from_s3(
commit_hash=commit_hash,
ray_version=ray_version,
directory_path=temp_dir,
branch=branch,
)
if build_tag:
add_build_tag_to_wheels(directory_path=temp_dir, build_tag=build_tag)
Expand Down