Skip to content

Commit

Permalink
chore: add coverage to gcs_utils
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501167752
  • Loading branch information
yinghsienwu authored and copybara-github committed Jan 11, 2023
1 parent ee6bb87 commit a6a792e
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 7 deletions.
4 changes: 0 additions & 4 deletions google/cloud/aiplatform/utils/gcs_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,10 +172,6 @@ def generate_gcs_directory_for_pipeline_artifacts(
"""Gets or creates the GCS directory for Vertex Pipelines artifacts.
Args:
service_account: Optional. Google Cloud service account that will be used
to run the pipelines. If this function creates a new bucket it will give
permission to the specified service account to access the bucket.
If not provided, the Google Cloud Compute Engine service account will be used.
project: Optional. Google Cloud Project that contains the staging bucket.
location: Optional. Google Cloud location to use for the staging bucket.
Expand Down
72 changes: 69 additions & 3 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
from google.cloud.aiplatform import compat, utils
from google.cloud.aiplatform.compat.types import pipeline_failure_policy
from google.cloud.aiplatform.utils import (
gcs_utils,
pipeline_utils,
prediction_utils,
tensorboard_utils,
Expand All @@ -52,9 +53,10 @@
model_service_client_default = model_service_client_v1


GCS_BUCKET = "FAKE_BUCKET"
GCS_PREFIX = "FAKE/PREFIX"
FAKE_FILENAME = "FAKE_FILENAME"
GCS_BUCKET = "fake-bucket"
GCS_PREFIX = "fake/prefix"
FAKE_FILENAME = "fake-filename"
EXPECTED_TIME = datetime.datetime(2023, 1, 6, 8, 54, 41, 734495)


@pytest.fixture
Expand All @@ -78,6 +80,31 @@ def get_blobs(prefix):
yield mock_storage_client


@pytest.fixture()
def mock_datetime():
with patch.object(datetime, "datetime", autospec=True) as mock_datetime:
mock_datetime.now.return_value = EXPECTED_TIME
yield mock_datetime


@pytest.fixture
def mock_storage_blob_upload_from_filename():
with patch(
"google.cloud.storage.Blob.upload_from_filename"
) as mock_blob_upload_from_filename, patch(
"google.cloud.storage.Bucket.exists", return_value=True
):
yield mock_blob_upload_from_filename


@pytest.fixture()
def mock_bucket_not_exist():
with patch("google.cloud.storage.Blob.from_string") as mock_bucket_not_exist, patch(
"google.cloud.storage.Bucket.exists", return_value=False
):
yield mock_bucket_not_exist


def test_invalid_region_raises_with_invalid_region():
with pytest.raises(ValueError):
aiplatform.utils.validate_region(region="us-east5")
Expand Down Expand Up @@ -458,6 +485,45 @@ def test_timestamped_unique_name():
assert re.match(r"\d{4}-\d{2}-\d{2}-\d{2}-\d{2}-\d{2}-.{5}", name)


@pytest.mark.usefixtures("google_auth_mock")
class TestGcsUtils:
def test_upload_to_gcs(self, json_file, mock_storage_blob_upload_from_filename):
gcs_utils.upload_to_gcs(json_file, f"gs://{GCS_BUCKET}/{GCS_PREFIX}")
assert mock_storage_blob_upload_from_filename.called_once_with(json_file)

def test_stage_local_data_in_gcs(
self, json_file, mock_datetime, mock_storage_blob_upload_from_filename
):
timestamp = EXPECTED_TIME.isoformat(sep="-", timespec="milliseconds")
staging_gcs_dir = f"gs://{GCS_BUCKET}/{GCS_PREFIX}"
data_uri = gcs_utils.stage_local_data_in_gcs(json_file, staging_gcs_dir)
assert mock_storage_blob_upload_from_filename.called_once_with(json_file)
assert (
data_uri
== f"{staging_gcs_dir}/vertex_ai_auto_staging/{timestamp}/test.json"
)

def test_generate_gcs_directory_for_pipeline_artifacts(self):
output = gcs_utils.generate_gcs_directory_for_pipeline_artifacts(
"project", "us-central1"
)
assert output == "gs://project-vertex-pipelines-us-central1/output_artifacts/"

def test_create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
self, mock_bucket_not_exist, mock_storage_client
):
output = (
gcs_utils.create_gcs_bucket_for_pipeline_artifacts_if_it_does_not_exist(
project="test-project", location="us-central1"
)
)
assert mock_storage_client.called
assert mock_bucket_not_exist.called
assert (
output == "gs://test-project-vertex-pipelines-us-central1/output_artifacts/"
)


class TestPipelineUtils:
SAMPLE_JOB_SPEC = {
"pipelineSpec": {
Expand Down

0 comments on commit a6a792e

Please sign in to comment.