Skip to content

Commit

Permalink
update tests
Browse files Browse the repository at this point in the history
  • Loading branch information
dizcology committed Nov 3, 2020
1 parent 57aa05f commit 8f9c589
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
2 changes: 1 addition & 1 deletion synth.metadata
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
"git": {
"name": ".",
"remote": "https://github.com/dizcology/python-aiplatform.git",
"sha": "60263c04ffd04dabd7cc95c138b9f1c87566208c"
"sha": "81da030c0af8902fd54c8e7b5e92255a532d0efb"
}
},
{
Expand Down
39 changes: 31 additions & 8 deletions tests/unit/gapic/aiplatform_v1beta1/test_prediction_service.py
Original file line number Diff line number Diff line change
Expand Up @@ -480,6 +480,7 @@ def test_predict(
assert args[0] == prediction_service.PredictRequest()

# Establish that the response is the type that we expect.

assert isinstance(response, prediction_service.PredictResponse)

assert response.deployed_model_id == "deployed_model_id_value"
Expand All @@ -490,14 +491,16 @@ def test_predict_from_dict():


@pytest.mark.asyncio
async def test_predict_async(transport: str = "grpc_asyncio"):
async def test_predict_async(
transport: str = "grpc_asyncio", request_type=prediction_service.PredictRequest
):
client = PredictionServiceAsyncClient(
credentials=credentials.AnonymousCredentials(), transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = prediction_service.PredictRequest()
request = request_type()

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(type(client.transport.predict), "__call__") as call:
Expand All @@ -514,14 +517,19 @@ async def test_predict_async(transport: str = "grpc_asyncio"):
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]

assert args[0] == request
assert args[0] == prediction_service.PredictRequest()

# Establish that the response is the type that we expect.
assert isinstance(response, prediction_service.PredictResponse)

assert response.deployed_model_id == "deployed_model_id_value"


@pytest.mark.asyncio
async def test_predict_async_from_dict():
await test_predict_async(request_type=dict)


def test_predict_field_headers():
client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)

Expand Down Expand Up @@ -603,7 +611,9 @@ def test_predict_flattened():
]

# https://github.com/googleapis/gapic-generator-python/issues/414
# assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE)
# assert args[0].parameters == struct.Value(
# null_value=struct.NullValue.NULL_VALUE
# )


def test_predict_flattened_error():
Expand Down Expand Up @@ -702,6 +712,7 @@ def test_explain(
assert args[0] == prediction_service.ExplainRequest()

# Establish that the response is the type that we expect.

assert isinstance(response, prediction_service.ExplainResponse)

assert response.deployed_model_id == "deployed_model_id_value"
Expand All @@ -712,14 +723,16 @@ def test_explain_from_dict():


@pytest.mark.asyncio
async def test_explain_async(transport: str = "grpc_asyncio"):
async def test_explain_async(
transport: str = "grpc_asyncio", request_type=prediction_service.ExplainRequest
):
client = PredictionServiceAsyncClient(
credentials=credentials.AnonymousCredentials(), transport=transport,
)

# Everything is optional in proto3 as far as the runtime is concerned,
# and we are mocking out the actual API, so just send an empty request.
request = prediction_service.ExplainRequest()
request = request_type()

# Mock the actual call within the gRPC stub, and fake the request.
with mock.patch.object(type(client.transport.explain), "__call__") as call:
Expand All @@ -736,14 +749,19 @@ async def test_explain_async(transport: str = "grpc_asyncio"):
assert len(call.mock_calls)
_, args, _ = call.mock_calls[0]

assert args[0] == request
assert args[0] == prediction_service.ExplainRequest()

# Establish that the response is the type that we expect.
assert isinstance(response, prediction_service.ExplainResponse)

assert response.deployed_model_id == "deployed_model_id_value"


@pytest.mark.asyncio
async def test_explain_async_from_dict():
await test_explain_async(request_type=dict)


def test_explain_field_headers():
client = PredictionServiceClient(credentials=credentials.AnonymousCredentials(),)

Expand Down Expand Up @@ -826,7 +844,9 @@ def test_explain_flattened():
]

# https://github.com/googleapis/gapic-generator-python/issues/414
# assert args[0].parameters == struct.Value(null_value=struct.NullValue.NULL_VALUE)
# assert args[0].parameters == struct.Value(
# null_value=struct.NullValue.NULL_VALUE
# )

assert args[0].deployed_model_id == "deployed_model_id_value"

Expand Down Expand Up @@ -1094,6 +1114,7 @@ def test_prediction_service_grpc_transport_channel():
)
assert transport.grpc_channel == channel
assert transport._host == "squid.clam.whelk:443"
assert transport._ssl_channel_credentials == None


def test_prediction_service_grpc_asyncio_transport_channel():
Expand All @@ -1105,6 +1126,7 @@ def test_prediction_service_grpc_asyncio_transport_channel():
)
assert transport.grpc_channel == channel
assert transport._host == "squid.clam.whelk:443"
assert transport._ssl_channel_credentials == None


@pytest.mark.parametrize(
Expand Down Expand Up @@ -1152,6 +1174,7 @@ def test_prediction_service_transport_channel_mtls_with_client_cert_source(
quota_project_id=None,
)
assert transport.grpc_channel == mock_grpc_channel
assert transport._ssl_channel_credentials == mock_ssl_cred


@pytest.mark.parametrize(
Expand Down

0 comments on commit 8f9c589

Please sign in to comment.