feat: Update Ray system tests to be compatible with new RoV 2.33 changes
PiperOrigin-RevId: 673047153
speedstorm1 authored and copybara-github committed Sep 10, 2024
1 parent 424ebbf commit 8c7bf27
Showing 5 changed files with 142 additions and 105 deletions.
206 changes: 109 additions & 97 deletions google/cloud/aiplatform/vertex_ray/bigquery_datasink.py
@@ -35,7 +35,11 @@
from ray.data._internal.remote_fn import cached_remote_fn
from ray.data.block import Block, BlockAccessor

from ray.data.datasource.datasink import Datasink
try:
from ray.data.datasource.datasink import Datasink
except ImportError:
# If datasink cannot be imported, Ray >=2.9.3 is not installed
Datasink = None


DEFAULT_MAX_RETRY_CNT = 10
@@ -48,102 +52,110 @@


# BigQuery write for Ray 2.33.0 and 2.9.3
class _BigQueryDatasink(Datasink):
def __init__(
self,
dataset: str,
project_id: Optional[str] = None,
max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
overwrite_table: Optional[bool] = True,
) -> None:
self.dataset = dataset
self.project_id = project_id or initializer.global_config.project
self.max_retry_cnt = max_retry_cnt
self.overwrite_table = overwrite_table

def on_write_start(self) -> None:
# Set up datasets to write
client = bigquery.Client(project=self.project_id, client_info=bq_info)
dataset_id = self.dataset.split(".", 1)[0]
try:
client.get_dataset(dataset_id)
except exceptions.NotFound:
client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
print("[Ray on Vertex AI]: Created dataset " + dataset_id)

# Delete table if overwrite_table is True
if self.overwrite_table:
print(
f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
+ " if it already exists since kwarg overwrite_table = True."
)
client.delete_table(f"{self.project_id}.{self.dataset}", not_found_ok=True)
else:
print(
"[Ray on Vertex AI]: The write will append to table "
+ f"{self.dataset} if it already exists "
+ "since kwarg overwrite_table = False."
)

def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> Any:
def _write_single_block(block: Block, project_id: str, dataset: str) -> None:
block = BlockAccessor.for_block(block).to_arrow()

client = bigquery.Client(project=project_id, client_info=bq_info)
job_config = bigquery.LoadJobConfig(autodetect=True)
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

with tempfile.TemporaryDirectory() as temp_dir:
fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
pq.write_table(block, fp, compression="SNAPPY")

retry_cnt = 0
while retry_cnt <= self.max_retry_cnt:
with open(fp, "rb") as source_file:
job = client.load_table_from_file(
source_file, dataset, job_config=job_config
)
try:
logging.info(job.result())
break
except exceptions.Forbidden as e:
retry_cnt += 1
if retry_cnt > self.max_retry_cnt:
if Datasink is None:
_BigQueryDatasink = None
else:

class _BigQueryDatasink(Datasink):
def __init__(
self,
dataset: str,
project_id: Optional[str] = None,
max_retry_cnt: int = DEFAULT_MAX_RETRY_CNT,
overwrite_table: Optional[bool] = True,
) -> None:
self.dataset = dataset
self.project_id = project_id or initializer.global_config.project
self.max_retry_cnt = max_retry_cnt
self.overwrite_table = overwrite_table

def on_write_start(self) -> None:
# Set up datasets to write
client = bigquery.Client(project=self.project_id, client_info=bq_info)
dataset_id = self.dataset.split(".", 1)[0]
try:
client.get_dataset(dataset_id)
except exceptions.NotFound:
client.create_dataset(f"{self.project_id}.{dataset_id}", timeout=30)
print("[Ray on Vertex AI]: Created dataset " + dataset_id)

# Delete table if overwrite_table is True
if self.overwrite_table:
print(
f"[Ray on Vertex AI]: Attempting to delete table {self.dataset}"
+ " if it already exists since kwarg overwrite_table = True."
)
client.delete_table(
f"{self.project_id}.{self.dataset}", not_found_ok=True
)
else:
print(
"[Ray on Vertex AI]: The write will append to table "
+ f"{self.dataset} if it already exists "
+ "since kwarg overwrite_table = False."
)

def write(
self,
blocks: Iterable[Block],
ctx: TaskContext,
) -> Any:
def _write_single_block(
block: Block, project_id: str, dataset: str
) -> None:
block = BlockAccessor.for_block(block).to_arrow()

client = bigquery.Client(project=project_id, client_info=bq_info)
job_config = bigquery.LoadJobConfig(autodetect=True)
job_config.source_format = bigquery.SourceFormat.PARQUET
job_config.write_disposition = bigquery.WriteDisposition.WRITE_APPEND

with tempfile.TemporaryDirectory() as temp_dir:
fp = os.path.join(temp_dir, f"block_{uuid.uuid4()}.parquet")
pq.write_table(block, fp, compression="SNAPPY")

retry_cnt = 0
while retry_cnt <= self.max_retry_cnt:
with open(fp, "rb") as source_file:
job = client.load_table_from_file(
source_file, dataset, job_config=job_config
)
try:
logging.info(job.result())
break
except exceptions.Forbidden as e:
retry_cnt += 1
if retry_cnt > self.max_retry_cnt:
break
print(
"[Ray on Vertex AI]: A block write encountered"
+ f" a rate limit exceeded error {retry_cnt} time(s)."
+ " Sleeping to try again."
)
logging.debug(e)
time.sleep(RATE_LIMIT_EXCEEDED_SLEEP_TIME)

# Raise exception if retry_cnt exceeds max_retry_cnt
if retry_cnt > self.max_retry_cnt:
print(
f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
+ " Ray will attempt to retry the block write via fault tolerance."
+ " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
)
raise RuntimeError(
f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
+ " repeated API rate limit exceeded responses. Consider"
+ " specifying the max_retry_cnt kwarg with a higher value."
)

# Raise exception if retry_cnt exceeds max_retry_cnt
if retry_cnt > self.max_retry_cnt:
print(
f"[Ray on Vertex AI]: Maximum ({self.max_retry_cnt}) retry count exceeded."
+ " Ray will attempt to retry the block write via fault tolerance."
+ " For more information, see https://docs.ray.io/en/latest/ray-core/fault_tolerance/tasks.html"
)
raise RuntimeError(
f"[Ray on Vertex AI]: Write failed due to {retry_cnt}"
+ " repeated API rate limit exceeded responses. Consider"
+ " specifiying the max_retry_cnt kwarg with a higher value."
)

_write_single_block = cached_remote_fn(_write_single_block)

# Launch a remote task for each block within this write task
ray.get(
[
_write_single_block.remote(block, self.project_id, self.dataset)
for block in blocks
]
)

return "ok"

_write_single_block = cached_remote_fn(_write_single_block)

# Launch a remote task for each block within this write task
ray.get(
[
_write_single_block.remote(block, self.project_id, self.dataset)
for block in blocks
]
)

return "ok"
9 changes: 6 additions & 3 deletions google/cloud/aiplatform/vertex_ray/data.py
@@ -23,9 +23,12 @@
_BigQueryDatasource,
)

from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
_BigQueryDatasink,
)
try:
from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
_BigQueryDatasink,
)
except ImportError:
_BigQueryDatasink = None

from google.cloud.aiplatform.vertex_ray.util._validation_utils import (
_V2_4_WARNING_MESSAGE,
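
Because _BigQueryDatasink can now be None, downstream code needs an availability check before constructing the sink. The guard below is a hypothetical sketch, not the actual body of vertex_ray.data.write_bigquery; the helper name and error message are assumptions.

import ray

try:
    from google.cloud.aiplatform.vertex_ray.bigquery_datasink import (
        _BigQueryDatasink,
    )
except ImportError:
    _BigQueryDatasink = None


def _require_bigquery_datasink() -> None:
    # Hypothetical helper: raise a clear error instead of an AttributeError
    # when the installed Ray predates the Datasink API.
    if _BigQueryDatasink is None:
        raise ImportError(
            "Writing to BigQuery through the Datasink API requires "
            f"Ray >= 2.9.3; found ray=={ray.__version__}."
        )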
4 changes: 2 additions & 2 deletions tests/system/vertex_ray/test_cluster_management.py
@@ -22,7 +22,7 @@
import pytest
import ray

# Local ray version will always be 2.4 regardless of cluster version due to
# Local ray version will always be 2.4.0 regardless of cluster version due to
# dependency conflicts. Remote job execution's Ray version is 2.9.
RAY_VERSION = "2.4.0"
PROJECT_ID = "ucaip-sample-tests"
@@ -31,7 +31,7 @@
class TestClusterManagement(e2e_base.TestEndToEnd):
_temp_prefix = "temp-rov-cluster-management"

@pytest.mark.parametrize("cluster_ray_version", ["2.9"])
@pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
def test_cluster_management(self, cluster_ray_version):
assert ray.__version__ == RAY_VERSION
aiplatform.init(project=PROJECT_ID, location="us-central1")
2 changes: 1 addition & 1 deletion tests/system/vertex_ray/test_job_submission_dashboard.py
@@ -35,7 +35,7 @@
class TestJobSubmissionDashboard(e2e_base.TestEndToEnd):
_temp_prefix = "temp-job-submission-dashboard"

@pytest.mark.parametrize("cluster_ray_version", ["2.9"])
@pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
def test_job_submission_dashboard(self, cluster_ray_version):
assert ray.__version__ == RAY_VERSION
aiplatform.init(project=PROJECT_ID, location="us-central1")
26 changes: 24 additions & 2 deletions tests/system/vertex_ray/test_ray_data.py
@@ -54,13 +54,35 @@
)
"""

my_script = {"2.9": my_script_ray29}
my_script_ray233 = """
import ray
import vertex_ray
override_num_blocks = 10
query = "SELECT * FROM `bigquery-public-data.ml_datasets.ulb_fraud_detection` LIMIT 10000000"
ds = vertex_ray.data.read_bigquery(
override_num_blocks=override_num_blocks,
query=query,
)
# The reads are lazy, so the end time cannot be captured until ds.materialize() is called
ds.materialize()
# Write
vertex_ray.data.write_bigquery(
ds,
dataset="bugbashbq1.system_test_ray29_write",
)
"""

my_script = {"2.9": my_script_ray29, "2.33": my_script_ray233}


class TestRayData(e2e_base.TestEndToEnd):
_temp_prefix = "temp-ray-data"

@pytest.mark.parametrize("cluster_ray_version", ["2.9"])
@pytest.mark.parametrize("cluster_ray_version", ["2.9", "2.33"])
def test_ray_data(self, cluster_ray_version):
head_node_type = vertex_ray.Resources()
worker_node_types = [
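
The rest of the test body is collapsed in this view, so the snippet below is only an assumed sketch of how a version-specific script could be staged and submitted with Ray's job submission API; the helper name, dashboard address argument, and runtime_env layout are placeholders rather than the test's real implementation.

import os
import tempfile

from ray.job_submission import JobSubmissionClient


def submit_version_script(dashboard_address: str, cluster_ray_version: str) -> str:
    # my_script is the dict defined in this test file above; pick the script
    # that matches the cluster's Ray version ("2.9" or "2.33").
    script = my_script[cluster_ray_version]

    # Stage the script in a temporary working directory for the job to run.
    workdir = tempfile.mkdtemp()
    with open(os.path.join(workdir, "job_script.py"), "w") as f:
        f.write(script)

    client = JobSubmissionClient(dashboard_address)
    return client.submit_job(
        entrypoint="python job_script.py",
        runtime_env={"working_dir": workdir},
    )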
