diff --git a/.github/workflows/mock_server_tests.yaml b/.github/workflows/mock_server_tests.yaml index 2da5320071..e93ac9905c 100644 --- a/.github/workflows/mock_server_tests.yaml +++ b/.github/workflows/mock_server_tests.yaml @@ -5,7 +5,7 @@ on: pull_request: name: Run Spanner tests against an in-mem mock server jobs: - system-tests: + mock-server-tests: runs-on: ubuntu-latest steps: diff --git a/.github/workflows/presubmit.yaml b/.github/workflows/presubmit.yaml new file mode 100644 index 0000000000..2d6132bd97 --- /dev/null +++ b/.github/workflows/presubmit.yaml @@ -0,0 +1,42 @@ +on: + push: + branches: + - main + pull_request: +name: Presubmit checks +permissions: + contents: read + pull-requests: write +jobs: + lint: + runs-on: ubuntu-latest + + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: 3.8 + - name: Install nox + run: python -m pip install nox + - name: Check formatting + run: nox -s lint + units: + runs-on: ubuntu-latest + strategy: + fail-fast: false + matrix: + python: ["3.8", "3.9", "3.10", "3.11", "3.12", "3.13"] + + steps: + - name: Checkout code + uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: ${{matrix.python}} + - name: Install nox + run: python -m pip install nox + - name: Run unit tests + run: nox -s unit-${{matrix.python}} diff --git a/.kokoro/presubmit/presubmit.cfg b/.kokoro/presubmit/presubmit.cfg index b158096f0a..14db9152d9 100644 --- a/.kokoro/presubmit/presubmit.cfg +++ b/.kokoro/presubmit/presubmit.cfg @@ -1,7 +1,7 @@ # Format: //devtools/kokoro/config/proto/build.proto -# Disable system tests. +# Only run a subset of all nox sessions env_vars: { - key: "RUN_SYSTEM_TESTS" - value: "false" + key: "NOX_SESSION" + value: "unit-3.8 unit-3.12 cover docs docfx" } diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 059e2a70df..4617e93bef 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -17,6 +17,8 @@ from google.api_core.exceptions import Aborted from google.api_core.gapic_v1.client_info import ClientInfo +from google.auth.credentials import AnonymousCredentials + from google.cloud import spanner_v1 as spanner from google.cloud.spanner_dbapi import partition_helper from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode, BatchDmlExecutor @@ -784,11 +786,15 @@ def connect( route_to_leader_enabled=route_to_leader_enabled, ) else: + client_options = None + if isinstance(credentials, AnonymousCredentials): + client_options = kwargs.get("client_options") client = spanner.Client( project=project, credentials=credentials, client_info=client_info, route_to_leader_enabled=route_to_leader_enabled, + client_options=client_options, ) else: if project is not None and client.project != project: diff --git a/google/cloud/spanner_dbapi/transaction_helper.py b/google/cloud/spanner_dbapi/transaction_helper.py index f8f5bfa584..744aeb7b43 100644 --- a/google/cloud/spanner_dbapi/transaction_helper.py +++ b/google/cloud/spanner_dbapi/transaction_helper.py @@ -162,7 +162,7 @@ def add_execute_statement_for_retry( self._last_statement_details_per_cursor[cursor] = last_statement_result_details self._statement_result_details_list.append(last_statement_result_details) - def retry_transaction(self): + def retry_transaction(self, default_retry_delay=None): """Retry the aborted transaction. All the statements executed in the original transaction @@ -202,7 +202,9 @@ def retry_transaction(self): raise RetryAborted(RETRY_ABORTED_ERROR, ex) return except Aborted as ex: - delay = _get_retry_delay(ex.errors[0], attempt) + delay = _get_retry_delay( + ex.errors[0], attempt, default_retry_delay=default_retry_delay + ) if delay: time.sleep(delay) diff --git a/google/cloud/spanner_v1/_helpers.py b/google/cloud/spanner_v1/_helpers.py index 7fa792a5f0..e76284864b 100644 --- a/google/cloud/spanner_v1/_helpers.py +++ b/google/cloud/spanner_v1/_helpers.py @@ -510,6 +510,7 @@ def _metadata_with_prefix(prefix, **kw): def _retry_on_aborted_exception( func, deadline, + default_retry_delay=None, ): """ Handles retry logic for Aborted exceptions, considering the deadline. @@ -520,7 +521,12 @@ def _retry_on_aborted_exception( attempts += 1 return func() except Aborted as exc: - _delay_until_retry(exc, deadline=deadline, attempts=attempts) + _delay_until_retry( + exc, + deadline=deadline, + attempts=attempts, + default_retry_delay=default_retry_delay, + ) continue @@ -608,7 +614,7 @@ def _metadata_with_span_context(metadata: List[Tuple[str, str]], **kw) -> None: inject(setter=OpenTelemetryContextSetter(), carrier=metadata) -def _delay_until_retry(exc, deadline, attempts): +def _delay_until_retry(exc, deadline, attempts, default_retry_delay=None): """Helper for :meth:`Session.run_in_transaction`. Detect retryable abort, and impose server-supplied delay. @@ -628,7 +634,7 @@ def _delay_until_retry(exc, deadline, attempts): if now >= deadline: raise - delay = _get_retry_delay(cause, attempts) + delay = _get_retry_delay(cause, attempts, default_retry_delay=default_retry_delay) if delay is not None: if now + delay > deadline: raise @@ -636,7 +642,7 @@ def _delay_until_retry(exc, deadline, attempts): time.sleep(delay) -def _get_retry_delay(cause, attempts): +def _get_retry_delay(cause, attempts, default_retry_delay=None): """Helper for :func:`_delay_until_retry`. :type exc: :class:`grpc.Call` @@ -658,6 +664,8 @@ def _get_retry_delay(cause, attempts): retry_info.ParseFromString(retry_info_pb) nanos = retry_info.retry_delay.nanos return retry_info.retry_delay.seconds + nanos / 1.0e9 + if default_retry_delay is not None: + return default_retry_delay return 2**attempts + random.random() diff --git a/google/cloud/spanner_v1/batch.py b/google/cloud/spanner_v1/batch.py index 39e29d4d41..3d632c7568 100644 --- a/google/cloud/spanner_v1/batch.py +++ b/google/cloud/spanner_v1/batch.py @@ -257,9 +257,11 @@ def commit( deadline = time.time() + kwargs.get( "timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS ) + default_retry_delay = kwargs.get("default_retry_delay", None) response = _retry_on_aborted_exception( method, deadline=deadline, + default_retry_delay=default_retry_delay, ) self.committed = response.commit_timestamp self.commit_stats = response.commit_stats diff --git a/google/cloud/spanner_v1/client.py b/google/cloud/spanner_v1/client.py index e201f93e9b..c006b965cf 100644 --- a/google/cloud/spanner_v1/client.py +++ b/google/cloud/spanner_v1/client.py @@ -241,7 +241,9 @@ def __init__( meter_provider = MeterProvider( metric_readers=[ PeriodicExportingMetricReader( - CloudMonitoringMetricsExporter(), + CloudMonitoringMetricsExporter( + project_id=project, credentials=credentials + ), export_interval_millis=METRIC_EXPORT_INTERVAL_MS, ) ] diff --git a/google/cloud/spanner_v1/metrics/metrics_exporter.py b/google/cloud/spanner_v1/metrics/metrics_exporter.py index e10cf6a2f1..68da08b400 100644 --- a/google/cloud/spanner_v1/metrics/metrics_exporter.py +++ b/google/cloud/spanner_v1/metrics/metrics_exporter.py @@ -26,6 +26,7 @@ from typing import Optional, List, Union, NoReturn, Tuple, Dict import google.auth +from google.auth import credentials as ga_credentials from google.api.distribution_pb2 import ( # pylint: disable=no-name-in-module Distribution, ) @@ -111,6 +112,7 @@ def __init__( self, project_id: Optional[str] = None, client: Optional["MetricServiceClient"] = None, + credentials: Optional[ga_credentials.Credentials] = None, ): """Initialize a custom exporter to send metrics for the Spanner Service Metrics.""" # Default preferred_temporality is all CUMULATIVE so need to customize @@ -121,6 +123,7 @@ def __init__( transport=MetricServiceGrpcTransport( channel=MetricServiceGrpcTransport.create_channel( options=_OPTIONS, + credentials=credentials, ) ) ) diff --git a/google/cloud/spanner_v1/session.py b/google/cloud/spanner_v1/session.py index f18ba57582..d5feb2ef1a 100644 --- a/google/cloud/spanner_v1/session.py +++ b/google/cloud/spanner_v1/session.py @@ -461,6 +461,7 @@ def run_in_transaction(self, func, *args, **kw): reraises any non-ABORT exceptions raised by ``func``. """ deadline = time.time() + kw.pop("timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS) + default_retry_delay = kw.pop("default_retry_delay", None) commit_request_options = kw.pop("commit_request_options", None) max_commit_delay = kw.pop("max_commit_delay", None) transaction_tag = kw.pop("transaction_tag", None) @@ -502,7 +503,11 @@ def run_in_transaction(self, func, *args, **kw): except Aborted as exc: del self._transaction if span: - delay_seconds = _get_retry_delay(exc.errors[0], attempts) + delay_seconds = _get_retry_delay( + exc.errors[0], + attempts, + default_retry_delay=default_retry_delay, + ) attributes = dict(delay_seconds=delay_seconds, cause=str(exc)) attributes.update(span_attributes) add_span_event( @@ -511,7 +516,9 @@ def run_in_transaction(self, func, *args, **kw): attributes, ) - _delay_until_retry(exc, deadline, attempts) + _delay_until_retry( + exc, deadline, attempts, default_retry_delay=default_retry_delay + ) continue except GoogleAPICallError: del self._transaction @@ -539,7 +546,11 @@ def run_in_transaction(self, func, *args, **kw): except Aborted as exc: del self._transaction if span: - delay_seconds = _get_retry_delay(exc.errors[0], attempts) + delay_seconds = _get_retry_delay( + exc.errors[0], + attempts, + default_retry_delay=default_retry_delay, + ) attributes = dict(delay_seconds=delay_seconds) attributes.update(span_attributes) add_span_event( @@ -548,7 +559,9 @@ def run_in_transaction(self, func, *args, **kw): attributes, ) - _delay_until_retry(exc, deadline, attempts) + _delay_until_retry( + exc, deadline, attempts, default_retry_delay=default_retry_delay + ) except GoogleAPICallError: del self._transaction add_span_event( diff --git a/noxfile.py b/noxfile.py index 73ad757240..be3a05c455 100644 --- a/noxfile.py +++ b/noxfile.py @@ -181,21 +181,6 @@ def install_unittest_dependencies(session, *constraints): # XXX: Dump installed versions to debug OT issue session.run("pip", "list") - # Run py.test against the unit tests with OpenTelemetry. - session.run( - "py.test", - "--quiet", - "--cov=google.cloud.spanner", - "--cov=google.cloud", - "--cov=tests.unit", - "--cov-append", - "--cov-config=.coveragerc", - "--cov-report=", - "--cov-fail-under=0", - os.path.join("tests", "unit"), - *session.posargs, - ) - @nox.session(python=UNIT_TEST_PYTHON_VERSIONS) @nox.parametrize( @@ -329,9 +314,12 @@ def system(session, protobuf_implementation, database_dialect): session.skip( "Credentials or emulator host must be set via environment variable" ) - # If POSTGRESQL tests and Emulator, skip the tests - if os.environ.get("SPANNER_EMULATOR_HOST") and database_dialect == "POSTGRESQL": - session.skip("Postgresql is not supported by Emulator yet.") + if not ( + os.environ.get("SPANNER_EMULATOR_HOST") or protobuf_implementation == "python" + ): + session.skip( + "Only run system tests on real Spanner with one protobuf implementation to speed up the build" + ) # Install pyopenssl for mTLS testing. if os.environ.get("GOOGLE_API_USE_CLIENT_CERTIFICATE", "false") == "true": @@ -365,7 +353,7 @@ def system(session, protobuf_implementation, database_dialect): "SKIP_BACKUP_TESTS": "true", }, ) - if system_test_folder_exists: + elif system_test_folder_exists: session.run( "py.test", "--quiet", @@ -567,30 +555,32 @@ def prerelease_deps(session, protobuf_implementation, database_dialect): system_test_path = os.path.join("tests", "system.py") system_test_folder_path = os.path.join("tests", "system") - # Only run system tests if found. - if os.path.exists(system_test_path): - session.run( - "py.test", - "--verbose", - f"--junitxml=system_{session.python}_sponge_log.xml", - system_test_path, - *session.posargs, - env={ - "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", - }, - ) - if os.path.exists(system_test_folder_path): - session.run( - "py.test", - "--verbose", - f"--junitxml=system_{session.python}_sponge_log.xml", - system_test_folder_path, - *session.posargs, - env={ - "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, - "SPANNER_DATABASE_DIALECT": database_dialect, - "SKIP_BACKUP_TESTS": "true", - }, - ) + # Only run system tests for one protobuf implementation on real Spanner to speed up the build. + if os.environ.get("SPANNER_EMULATOR_HOST") or protobuf_implementation == "python": + # Only run system tests if found. + if os.path.exists(system_test_path): + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_path, + *session.posargs, + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + "SPANNER_DATABASE_DIALECT": database_dialect, + "SKIP_BACKUP_TESTS": "true", + }, + ) + elif os.path.exists(system_test_folder_path): + session.run( + "py.test", + "--verbose", + f"--junitxml=system_{session.python}_sponge_log.xml", + system_test_folder_path, + *session.posargs, + env={ + "PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION": protobuf_implementation, + "SPANNER_DATABASE_DIALECT": database_dialect, + "SKIP_BACKUP_TESTS": "true", + }, + ) diff --git a/tests/unit/spanner_dbapi/test_connect.py b/tests/unit/spanner_dbapi/test_connect.py index 47d8b4f6a5..b3314fe2bc 100644 --- a/tests/unit/spanner_dbapi/test_connect.py +++ b/tests/unit/spanner_dbapi/test_connect.py @@ -17,8 +17,8 @@ import unittest from unittest import mock -import google.auth.credentials - +import google +from google.auth.credentials import AnonymousCredentials INSTANCE = "test-instance" DATABASE = "test-database" @@ -45,7 +45,13 @@ def test_w_implicit(self, mock_client): instance = client.instance.return_value database = instance.database.return_value - connection = connect(INSTANCE, DATABASE) + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) self.assertIsInstance(connection, Connection) @@ -55,6 +61,7 @@ def test_w_implicit(self, mock_client): project=mock.ANY, credentials=mock.ANY, client_info=mock.ANY, + client_options=mock.ANY, route_to_leader_enabled=True, ) @@ -92,6 +99,7 @@ def test_w_explicit(self, mock_client): project=PROJECT, credentials=credentials, client_info=mock.ANY, + client_options=mock.ANY, route_to_leader_enabled=False, ) client_info = mock_client.call_args_list[0][1]["client_info"] diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py index 4bee9e93c7..6f478dfe57 100644 --- a/tests/unit/spanner_dbapi/test_connection.py +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -19,6 +19,7 @@ import unittest import warnings import pytest +from google.auth.credentials import AnonymousCredentials from google.cloud.spanner_admin_database_v1 import DatabaseDialect from google.cloud.spanner_dbapi.batch_dml_executor import BatchMode @@ -68,7 +69,11 @@ def _make_connection( from google.cloud.spanner_v1.client import Client # We don't need a real Client object to test the constructor - client = Client() + client = Client( + project="test", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) instance = Instance(INSTANCE, client=client) database = instance.database(DATABASE, database_dialect=database_dialect) return Connection(instance, database, **kwargs) @@ -239,7 +244,13 @@ def test_close(self): from google.cloud.spanner_dbapi import connect from google.cloud.spanner_dbapi import InterfaceError - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) self.assertFalse(connection.is_closed) @@ -830,7 +841,12 @@ def test_invalid_custom_client_connection(self): def test_connection_wo_database(self): from google.cloud.spanner_dbapi import connect - connection = connect("test-instance") + connection = connect( + "test-instance", + credentials=AnonymousCredentials(), + project="test-project", + client_options={"api_endpoint": "none"}, + ) self.assertTrue(connection.database is None) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py index 2a8cddac9b..b96e8c1444 100644 --- a/tests/unit/spanner_dbapi/test_cursor.py +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -16,6 +16,8 @@ from unittest import mock import sys import unittest + +from google.auth.credentials import AnonymousCredentials from google.rpc.code_pb2 import ABORTED from google.cloud.spanner_dbapi.parsed_statement import ( @@ -127,7 +129,13 @@ def test_do_batch_update(self): sql = "DELETE FROM table WHERE col1 = %s" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) connection.autocommit = True transaction = self._transaction_mock(mock_response=[1, 1, 1]) @@ -479,7 +487,13 @@ def test_executemany_DLL(self, mock_client): def test_executemany_client_statement(self): from google.cloud.spanner_dbapi import connect, ProgrammingError - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) cursor = connection.cursor() @@ -497,7 +511,13 @@ def test_executemany(self, mock_client): operation = """SELECT * FROM table1 WHERE "col1" = @a1""" params_seq = ((1,), (2,)) - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) cursor = connection.cursor() cursor._result_set = [1, 2, 3] @@ -519,7 +539,13 @@ def test_executemany_delete_batch_autocommit(self): sql = "DELETE FROM table WHERE col1 = %s" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) connection.autocommit = True transaction = self._transaction_mock() @@ -551,7 +577,13 @@ def test_executemany_update_batch_autocommit(self): sql = "UPDATE table SET col1 = %s WHERE col2 = %s" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) connection.autocommit = True transaction = self._transaction_mock() @@ -595,7 +627,13 @@ def test_executemany_insert_batch_non_autocommit(self): sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)""" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) transaction = self._transaction_mock() @@ -632,7 +670,13 @@ def test_executemany_insert_batch_autocommit(self): sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)""" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) connection.autocommit = True @@ -676,7 +720,13 @@ def test_executemany_insert_batch_failed(self): sql = """INSERT INTO table (col1, "col2", `col3`, `"col4"`) VALUES (%s, %s, %s, %s)""" err_details = "Details here" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) connection.autocommit = True cursor = connection.cursor() @@ -705,7 +755,13 @@ def test_executemany_insert_batch_aborted(self): args = [(1, 2, 3, 4), (5, 6, 7, 8)] err_details = "Aborted details here" - connection = connect("test-instance", "test-database") + connection = connect( + "test-instance", + "test-database", + project="test-project", + credentials=AnonymousCredentials(), + client_options={"api_endpoint": "none"}, + ) transaction1 = mock.Mock() transaction1.batch_update = mock.Mock( diff --git a/tests/unit/spanner_dbapi/test_transaction_helper.py b/tests/unit/spanner_dbapi/test_transaction_helper.py index 1d50a51825..958fca0ce6 100644 --- a/tests/unit/spanner_dbapi/test_transaction_helper.py +++ b/tests/unit/spanner_dbapi/test_transaction_helper.py @@ -323,7 +323,7 @@ def test_retry_transaction_aborted_retry(self): None, ] - self._under_test.retry_transaction() + self._under_test.retry_transaction(default_retry_delay=0) run_mock.assert_has_calls( ( diff --git a/tests/unit/test__helpers.py b/tests/unit/test__helpers.py index 7010affdd2..d29f030e55 100644 --- a/tests/unit/test__helpers.py +++ b/tests/unit/test__helpers.py @@ -15,6 +15,7 @@ import unittest import mock + from google.cloud.spanner_v1 import TransactionOptions @@ -824,7 +825,7 @@ def test_retry_on_error(self): True, ] - _retry(functools.partial(test_api.test_fxn)) + _retry(functools.partial(test_api.test_fxn), delay=0) self.assertEqual(test_api.test_fxn.call_count, 3) @@ -844,6 +845,7 @@ def test_retry_allowed_exceptions(self): _retry( functools.partial(test_api.test_fxn), allowed_exceptions={NotFound: None}, + delay=0, ) self.assertEqual(test_api.test_fxn.call_count, 2) @@ -860,7 +862,7 @@ def test_retry_count(self): ] with self.assertRaises(InternalServerError): - _retry(functools.partial(test_api.test_fxn), retry_count=1) + _retry(functools.partial(test_api.test_fxn), retry_count=1, delay=0) self.assertEqual(test_api.test_fxn.call_count, 2) @@ -879,6 +881,7 @@ def test_check_rst_stream_error(self): _retry( functools.partial(test_api.test_fxn), allowed_exceptions={InternalServerError: _check_rst_stream_error}, + delay=0, ) self.assertEqual(test_api.test_fxn.call_count, 3) @@ -896,7 +899,7 @@ def test_retry_on_aborted_exception_with_success_after_first_aborted_retry(self) ] deadline = time.time() + 30 result_after_retry = _retry_on_aborted_exception( - functools.partial(test_api.test_fxn), deadline + functools.partial(test_api.test_fxn), deadline, default_retry_delay=0 ) self.assertEqual(test_api.test_fxn.call_count, 2) @@ -910,16 +913,18 @@ def test_retry_on_aborted_exception_with_success_after_three_retries(self): test_api = mock.create_autospec(self.test_class) # Case where aborted exception is thrown after other generic exceptions + aborted = Aborted("aborted exception", errors=["Aborted error"]) test_api.test_fxn.side_effect = [ - Aborted("aborted exception", errors=("Aborted error")), - Aborted("aborted exception", errors=("Aborted error")), - Aborted("aborted exception", errors=("Aborted error")), + aborted, + aborted, + aborted, "true", ] deadline = time.time() + 30 _retry_on_aborted_exception( functools.partial(test_api.test_fxn), deadline=deadline, + default_retry_delay=0, ) self.assertEqual(test_api.test_fxn.call_count, 4) @@ -935,10 +940,12 @@ def test_retry_on_aborted_exception_raises_aborted_if_deadline_expires(self): Aborted("aborted exception", errors=("Aborted error")), "true", ] - deadline = time.time() + 0.1 + deadline = time.time() + 0.001 with self.assertRaises(Aborted): _retry_on_aborted_exception( - functools.partial(test_api.test_fxn), deadline=deadline + functools.partial(test_api.test_fxn), + deadline=deadline, + default_retry_delay=0.01, ) self.assertEqual(test_api.test_fxn.call_count, 1) diff --git a/tests/unit/test_batch.py b/tests/unit/test_batch.py index 2cea740ab6..355ce20520 100644 --- a/tests/unit/test_batch.py +++ b/tests/unit/test_batch.py @@ -277,17 +277,13 @@ def test_aborted_exception_on_commit_with_retries(self): # Assertion: Ensure that calling batch.commit() raises the Aborted exception with self.assertRaises(Aborted) as context: - batch.commit() + batch.commit(timeout_secs=0.1, default_retry_delay=0) # Verify additional details about the exception self.assertEqual(str(context.exception), "409 Transaction was aborted") self.assertGreater( api.commit.call_count, 1, "commit should be called more than once" ) - # Since we are using exponential backoff here and default timeout is set to 30 sec 2^x <= 30. So value for x will be 4 - self.assertEqual( - api.commit.call_count, 4, "commit should be called exactly 4 times" - ) def _test_commit_with_options( self, diff --git a/tests/unit/test_client.py b/tests/unit/test_client.py index a464209874..6084224a84 100644 --- a/tests/unit/test_client.py +++ b/tests/unit/test_client.py @@ -16,6 +16,8 @@ import os import mock +from google.auth.credentials import AnonymousCredentials + from google.cloud.spanner_v1 import DirectedReadOptions, DefaultTransactionOptions @@ -513,7 +515,7 @@ def test_list_instance_configs(self): from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsRequest from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsResponse - api = InstanceAdminClient() + api = InstanceAdminClient(credentials=AnonymousCredentials()) credentials = _make_credentials() client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api @@ -560,8 +562,8 @@ def test_list_instance_configs_w_options(self): from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsRequest from google.cloud.spanner_admin_instance_v1 import ListInstanceConfigsResponse - api = InstanceAdminClient() credentials = _make_credentials() + api = InstanceAdminClient(credentials=credentials) client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api @@ -636,8 +638,8 @@ def test_list_instances(self): from google.cloud.spanner_admin_instance_v1 import ListInstancesRequest from google.cloud.spanner_admin_instance_v1 import ListInstancesResponse - api = InstanceAdminClient() credentials = _make_credentials() + api = InstanceAdminClient(credentials=credentials) client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api @@ -684,8 +686,8 @@ def test_list_instances_w_options(self): from google.cloud.spanner_admin_instance_v1 import ListInstancesRequest from google.cloud.spanner_admin_instance_v1 import ListInstancesResponse - api = InstanceAdminClient() credentials = _make_credentials() + api = InstanceAdminClient(credentials=credentials) client = self._make_one(project=self.PROJECT, credentials=credentials) client._instance_admin_api = api diff --git a/tests/unit/test_database.py b/tests/unit/test_database.py index c270a0944a..c7ed5a0e3d 100644 --- a/tests/unit/test_database.py +++ b/tests/unit/test_database.py @@ -1916,7 +1916,7 @@ def test_context_mgr_w_aborted_commit_status(self): pool = database._pool = _Pool() session = _Session(database) pool.put(session) - checkout = self._make_one(database) + checkout = self._make_one(database, timeout_secs=0.1, default_retry_delay=0) with self.assertRaises(Aborted): with checkout as batch: @@ -1935,9 +1935,7 @@ def test_context_mgr_w_aborted_commit_status(self): return_commit_stats=True, request_options=RequestOptions(), ) - # Asserts that the exponential backoff retry for aborted transactions with a 30-second deadline - # allows for a maximum of 4 retries (2^x <= 30) to stay within the time limit. - self.assertEqual(api.commit.call_count, 4) + self.assertGreater(api.commit.call_count, 1) api.commit.assert_any_call( request=request, metadata=[ diff --git a/tests/unit/test_instance.py b/tests/unit/test_instance.py index e7ad729438..f3bf6726c0 100644 --- a/tests/unit/test_instance.py +++ b/tests/unit/test_instance.py @@ -14,6 +14,8 @@ import unittest import mock +from google.auth.credentials import AnonymousCredentials + from google.cloud.spanner_v1 import DefaultTransactionOptions @@ -586,7 +588,7 @@ def test_list_databases(self): from google.cloud.spanner_admin_database_v1 import ListDatabasesRequest from google.cloud.spanner_admin_database_v1 import ListDatabasesResponse - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -625,7 +627,7 @@ def test_list_databases_w_options(self): from google.cloud.spanner_admin_database_v1 import ListDatabasesRequest from google.cloud.spanner_admin_database_v1 import ListDatabasesResponse - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -704,7 +706,7 @@ def test_list_backups_defaults(self): from google.cloud.spanner_admin_database_v1 import ListBackupsRequest from google.cloud.spanner_admin_database_v1 import ListBackupsResponse - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -743,7 +745,7 @@ def test_list_backups_w_options(self): from google.cloud.spanner_admin_database_v1 import ListBackupsRequest from google.cloud.spanner_admin_database_v1 import ListBackupsResponse - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -787,7 +789,7 @@ def test_list_backup_operations_defaults(self): from google.longrunning import operations_pb2 from google.protobuf.any_pb2 import Any - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -832,7 +834,7 @@ def test_list_backup_operations_w_options(self): from google.longrunning import operations_pb2 from google.protobuf.any_pb2 import Any - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -884,7 +886,7 @@ def test_list_database_operations_defaults(self): from google.longrunning import operations_pb2 from google.protobuf.any_pb2 import Any - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) @@ -941,7 +943,7 @@ def test_list_database_operations_w_options(self): from google.longrunning import operations_pb2 from google.protobuf.any_pb2 import Any - api = DatabaseAdminClient() + api = DatabaseAdminClient(credentials=AnonymousCredentials()) client = _Client(self.PROJECT) client.database_admin_api = api instance = self._make_one(self.INSTANCE_ID, client) diff --git a/tests/unit/test_metrics.py b/tests/unit/test_metrics.py index bb2695553b..59fe6d2f61 100644 --- a/tests/unit/test_metrics.py +++ b/tests/unit/test_metrics.py @@ -15,6 +15,9 @@ import pytest from unittest.mock import MagicMock from google.api_core.exceptions import ServiceUnavailable +from google.auth import exceptions +from google.auth.credentials import Credentials + from google.cloud.spanner_v1.client import Client from unittest.mock import patch from grpc._interceptor import _UnaryOutcome @@ -28,6 +31,26 @@ # pytest.importorskip("opentelemetry.semconv.attributes.otel_attributes") +class TestCredentials(Credentials): + @property + def expired(self): + return False + + @property + def valid(self): + return True + + def refresh(self, request): + raise exceptions.InvalidOperation("Anonymous credentials cannot be refreshed.") + + def apply(self, headers, token=None): + if token is not None: + raise exceptions.InvalidValue("Anonymous credentials don't support tokens.") + + def before_request(self, request, method, url, headers): + """Anonymous credentials do nothing to the request.""" + + @pytest.fixture(autouse=True) def patched_client(monkeypatch): monkeypatch.setenv("SPANNER_ENABLE_BUILTIN_METRICS", "true") @@ -37,7 +60,11 @@ def patched_client(monkeypatch): if SpannerMetricsTracerFactory._metrics_tracer_factory is not None: SpannerMetricsTracerFactory._metrics_tracer_factory = None - client = Client() + client = Client( + project="test", + credentials=TestCredentials(), + # client_options={"api_endpoint": "none"} + ) yield client # Resetting diff --git a/tests/unit/test_metrics_exporter.py b/tests/unit/test_metrics_exporter.py index 62fb531345..f57984ec66 100644 --- a/tests/unit/test_metrics_exporter.py +++ b/tests/unit/test_metrics_exporter.py @@ -14,6 +14,9 @@ import unittest from unittest.mock import patch, MagicMock, Mock + +from google.auth.credentials import AnonymousCredentials + from google.cloud.spanner_v1.metrics.metrics_exporter import ( CloudMonitoringMetricsExporter, _normalize_label_key, @@ -74,10 +77,6 @@ def setUp(self): unit="counts", ) - def test_default_ctor(self): - exporter = CloudMonitoringMetricsExporter() - self.assertIsNotNone(exporter.project_id) - def test_normalize_label_key(self): """Test label key normalization""" test_cases = [ @@ -236,7 +235,9 @@ def test_metric_timeseries_conversion(self): metrics = self.metric_reader.get_metrics_data() self.assertTrue(metrics is not None) - exporter = CloudMonitoringMetricsExporter(PROJECT_ID) + exporter = CloudMonitoringMetricsExporter( + PROJECT_ID, credentials=AnonymousCredentials() + ) timeseries = exporter._resource_metrics_to_timeseries_pb(metrics) # Both counter values should be summed together @@ -257,7 +258,9 @@ def test_metric_timeseries_scope_filtering(self): # Export metrics metrics = self.metric_reader.get_metrics_data() - exporter = CloudMonitoringMetricsExporter(PROJECT_ID) + exporter = CloudMonitoringMetricsExporter( + PROJECT_ID, credentials=AnonymousCredentials() + ) timeseries = exporter._resource_metrics_to_timeseries_pb(metrics) # Metris with incorrect sope should be filtered out @@ -342,7 +345,9 @@ def test_export_early_exit_if_extras_not_installed(self): with self.assertLogs( "google.cloud.spanner_v1.metrics.metrics_exporter", level="WARNING" ) as log: - exporter = CloudMonitoringMetricsExporter(PROJECT_ID) + exporter = CloudMonitoringMetricsExporter( + PROJECT_ID, credentials=AnonymousCredentials() + ) self.assertFalse(exporter.export([])) self.assertIn( "WARNING:google.cloud.spanner_v1.metrics.metrics_exporter:Metric exporter called without dependencies installed.", @@ -382,12 +387,16 @@ def test_export(self): def test_force_flush(self): """Verify that the unimplemented force flush can be called.""" - exporter = CloudMonitoringMetricsExporter(PROJECT_ID) + exporter = CloudMonitoringMetricsExporter( + PROJECT_ID, credentials=AnonymousCredentials() + ) self.assertTrue(exporter.force_flush()) def test_shutdown(self): """Verify that the unimplemented shutdown can be called.""" - exporter = CloudMonitoringMetricsExporter() + exporter = CloudMonitoringMetricsExporter( + project_id="test", credentials=AnonymousCredentials() + ) try: exporter.shutdown() except Exception as e: @@ -409,7 +418,9 @@ def test_metrics_to_time_series_empty_input( self, mocked_data_point_to_timeseries_pb ): """Verify that metric entries with no timeseries data do not return a time series entry.""" - exporter = CloudMonitoringMetricsExporter() + exporter = CloudMonitoringMetricsExporter( + project_id="test", credentials=AnonymousCredentials() + ) data_point = Mock() metric = Mock(data_points=[data_point]) scope_metric = Mock( @@ -422,7 +433,9 @@ def test_metrics_to_time_series_empty_input( def test_to_point(self): """Verify conversion of datapoints.""" - exporter = CloudMonitoringMetricsExporter() + exporter = CloudMonitoringMetricsExporter( + project_id="test", credentials=AnonymousCredentials() + ) number_point = NumberDataPoint( attributes=[], start_time_unix_nano=0, time_unix_nano=0, value=9 diff --git a/tests/unit/test_pool.py b/tests/unit/test_pool.py index a9593b3651..768f8482f3 100644 --- a/tests/unit/test_pool.py +++ b/tests/unit/test_pool.py @@ -283,7 +283,7 @@ def test_spans_bind_get_empty_pool(self): return # Tests trying to invoke pool.get() from an empty pool. - pool = self._make_one(size=0) + pool = self._make_one(size=0, default_timeout=0.1) database = _Database("name") session1 = _Session(database) with trace_call("pool.Get", session1): diff --git a/tests/unit/test_session.py b/tests/unit/test_session.py index 8f5f7039b9..d72c01f5ab 100644 --- a/tests/unit/test_session.py +++ b/tests/unit/test_session.py @@ -1031,7 +1031,9 @@ def unit_of_work(txn, *args, **kw): txn.insert(TABLE_NAME, COLUMNS, VALUES) return "answer" - return_value = session.run_in_transaction(unit_of_work, "abc", some_arg="def") + return_value = session.run_in_transaction( + unit_of_work, "abc", some_arg="def", default_retry_delay=0 + ) self.assertEqual(len(called_with), 2) for index, (txn, args, kw) in enumerate(called_with): @@ -1858,7 +1860,7 @@ def _time_func(): # check if current time > deadline with mock.patch("time.time", _time_func): with self.assertRaises(Exception): - _delay_until_retry(exc_mock, 2, 1) + _delay_until_retry(exc_mock, 2, 1, default_retry_delay=0) with mock.patch("time.time", _time_func): with mock.patch(