Skip to content

Commit

Permalink
fix: address broken unit tests in certain environments
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 501875885
  • Loading branch information
ucdmkt authored and copybara-github committed Jan 13, 2023
1 parent 9ffd173 commit d06b22d
Show file tree
Hide file tree
Showing 4 changed files with 28 additions and 6 deletions.
29 changes: 23 additions & 6 deletions google/cloud/aiplatform/vizier/pyvizier/study_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -117,19 +117,36 @@ class SearchSpace(SearchSpace):
@classmethod
def from_proto(cls, proto: study_pb2.StudySpec) -> "SearchSpace":
"""Extracts a SearchSpace object from a StudyConfig proto."""
parameter_configs = []

# For google-vizier <= 0.0.15
if hasattr(cls, "_factory"):
parameter_configs = []
for pc in proto.parameters:
parameter_configs.append(
proto_converters.ParameterConfigConverter.from_proto(pc)
)
return cls._factory(parameter_configs=parameter_configs)

result = cls()
for pc in proto.parameters:
parameter_configs.append(
proto_converters.ParameterConfigConverter.from_proto(pc)
)
return cls._factory(parameter_configs=parameter_configs)
result.add(proto_converters.ParameterConfigConverter.from_proto(pc))

return result

@property
def parameter_protos(self) -> List[study_pb2.StudySpec.ParameterSpec]:
"""Returns the search space as a List of ParameterConfig protos."""

# For google-vizier <= 0.0.15
if isinstance(self._parameter_configs, list):
return [
proto_converters.ParameterConfigConverter.to_proto(pc)
for pc in self._parameter_configs
]

return [
proto_converters.ParameterConfigConverter.to_proto(pc)
for pc in self._parameter_configs
for _, pc in self._parameter_configs.items()
]


Expand Down
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_metadata_resources.py
Original file line number Diff line number Diff line change
Expand Up @@ -614,6 +614,7 @@ def list_artifact_empty_mock():
yield list_artifacts_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestExecution:
def setup_method(self):
reload(initializer)
Expand Down Expand Up @@ -893,6 +894,7 @@ def test_query_input_and_output_artifacts(
assert artifact_list[0]._gca_resource == expected_artifact


@pytest.mark.usefixtures("google_auth_mock")
class TestArtifact:
def setup_method(self):
reload(initializer)
Expand Down
1 change: 1 addition & 0 deletions tests/unit/aiplatform/test_metadata_store.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def delete_metadata_store_mock():
yield delete_metadata_store_mock


@pytest.mark.usefixtures("google_auth_mock")
class TestMetadataStore:
def setup_method(self):
reload(initializer)
Expand Down
2 changes: 2 additions & 0 deletions tests/unit/aiplatform/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,6 +390,7 @@ def test_wrapped_client():
)


@pytest.mark.usefixtures("google_auth_mock")
def test_client_w_override_default_version():

test_client_info = gapic_v1.client_info.ClientInfo()
Expand All @@ -407,6 +408,7 @@ def test_client_w_override_default_version():
)


@pytest.mark.usefixtures("google_auth_mock")
def test_client_w_override_select_version():

test_client_info = gapic_v1.client_info.ClientInfo()
Expand Down

0 comments on commit d06b22d

Please sign in to comment.