diff --git a/google/cloud/bigtable/data/_async/_mutate_rows.py b/google/cloud/bigtable/data/_async/_mutate_rows.py index 7d1144553..40e956879 100644 --- a/google/cloud/bigtable/data/_async/_mutate_rows.py +++ b/google/cloud/bigtable/data/_async/_mutate_rows.py @@ -25,6 +25,7 @@ from google.cloud.bigtable.data._helpers import _make_metadata from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _retry_exception_factory +from google.cloud.bigtable.data._helpers import backoff_generator # mutate_rows requests are limited to this number of mutations from google.cloud.bigtable.data.mutations import _MUTATE_ROWS_REQUEST_MUTATION_LIMIT @@ -35,6 +36,7 @@ ) from google.cloud.bigtable.data.mutations import RowMutationEntry from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._metrics import ActiveOperationMetric @dataclass @@ -65,6 +67,7 @@ def __init__( mutation_entries: list["RowMutationEntry"], operation_timeout: float, attempt_timeout: float | None, + metrics: ActiveOperationMetric, retryable_exceptions: Sequence[type[Exception]] = (), ): """ @@ -75,6 +78,8 @@ def __init__( - operation_timeout: the timeout to use for the entire operation, in seconds. - attempt_timeout: the timeout to use for each mutate_rows attempt, in seconds. If not specified, the request will run until operation_timeout is reached. 
+ - metrics: the metrics object to use for tracking the operation + - retryable_exceptions: a list of exceptions that should be retried """ # check that mutations are within limits total_mutations = sum(len(entry.mutations) for entry in mutation_entries) @@ -100,7 +105,7 @@ def __init__( # Entry level errors bt_exceptions._MutateRowsIncomplete, ) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + sleep_generator = backoff_generator(0.01, 2, 60) self._operation = retries.retry_target_async( self._run_attempt, self.is_retryable, @@ -115,6 +120,9 @@ def __init__( self.mutations = [_EntryWithProto(m, m._to_pb()) for m in mutation_entries] self.remaining_indices = list(range(len(self.mutations))) self.errors: dict[int, list[Exception]] = {} + # set up metrics + metrics.backoff_generator = sleep_generator + self._operation_metrics = metrics async def start(self): """ @@ -136,9 +144,11 @@ async def start(self): all_errors: list[Exception] = [] for idx, exc_list in self.errors.items(): if len(exc_list) == 0: - raise core_exceptions.ClientError( + exc = core_exceptions.ClientError( f"Mutation {idx} failed with no associated errors" ) + self._operation_metrics.end_with_status(exc) + raise exc elif len(exc_list) == 1: cause_exc = exc_list[0] else: @@ -148,9 +158,13 @@ async def start(self): bt_exceptions.FailedMutationEntryError(idx, entry, cause_exc) ) if all_errors: - raise bt_exceptions.MutationsExceptionGroup( + combined_exc = bt_exceptions.MutationsExceptionGroup( all_errors, len(self.mutations) ) + self._operation_metrics.end_with_status(combined_exc) + raise combined_exc + else: + self._operation_metrics.end_with_success() async def _run_attempt(self): """ @@ -161,6 +175,8 @@ async def _run_attempt(self): retry after the attempt is complete - GoogleAPICallError: if the gapic rpc fails """ + # register attempt start + self._operation_metrics.start_attempt() request_entries = [self.mutations[idx].proto for idx in self.remaining_indices] # track 
mutations in this request that have not been finalized yet active_request_indices = { @@ -177,34 +193,47 @@ async def _run_attempt(self): entries=request_entries, retry=None, ) - async for result_list in result_generator: - for result in result_list.entries: - # convert sub-request index to global index - orig_idx = active_request_indices[result.index] - entry_error = core_exceptions.from_grpc_status( - result.status.code, - result.status.message, - details=result.status.details, - ) - if result.status.code != 0: - # mutation failed; update error list (and remaining_indices if retryable) - self._handle_entry_error(orig_idx, entry_error) - elif orig_idx in self.errors: - # mutation succeeded; remove from error list - del self.errors[orig_idx] - # remove processed entry from active list - del active_request_indices[result.index] + try: + async for result_list in result_generator: + for result in result_list.entries: + # convert sub-request index to global index + orig_idx = active_request_indices[result.index] + entry_error = core_exceptions.from_grpc_status( + result.status.code, + result.status.message, + details=result.status.details, + ) + if result.status.code != 0: + # mutation failed; update error list (and remaining_indices if retryable) + self._handle_entry_error(orig_idx, entry_error) + elif orig_idx in self.errors: + # mutation succeeded; remove from error list + del self.errors[orig_idx] + # remove processed entry from active list + del active_request_indices[result.index] + finally: + # send trailing metadata to metrics + result_generator.cancel() + metadata = ( + await result_generator.trailing_metadata() + + await result_generator.initial_metadata() + ) + self._operation_metrics.add_response_metadata(metadata) except Exception as exc: # add this exception to list for each mutation that wasn't # already handled, and update remaining_indices if mutation is retryable for idx in active_request_indices.values(): self._handle_entry_error(idx, exc) + # record 
attempt failure metric + self._operation_metrics.end_attempt_with_status(exc) # bubble up exception to be handled by retry wrapper raise # check if attempt succeeded, or needs to be retried if self.remaining_indices: # unfinished work; raise exception to trigger retry - raise bt_exceptions._MutateRowsIncomplete + last_exc = self.errors[self.remaining_indices[-1]][-1] + self._operation_metrics.end_attempt_with_status(last_exc) + raise bt_exceptions._MutateRowsIncomplete() def _handle_entry_error(self, idx: int, exc: Exception): """ diff --git a/google/cloud/bigtable/data/_async/_read_rows.py b/google/cloud/bigtable/data/_async/_read_rows.py index 9e0fd78e1..41cfdeb5f 100644 --- a/google/cloud/bigtable/data/_async/_read_rows.py +++ b/google/cloud/bigtable/data/_async/_read_rows.py @@ -18,10 +18,10 @@ from typing import ( TYPE_CHECKING, AsyncGenerator, - AsyncIterable, Awaitable, Sequence, ) +import time from google.cloud.bigtable_v2.types import ReadRowsRequest as ReadRowsRequestPB from google.cloud.bigtable_v2.types import ReadRowsResponse as ReadRowsResponsePB @@ -34,13 +34,16 @@ from google.cloud.bigtable.data.exceptions import _RowSetComplete from google.cloud.bigtable.data._helpers import _attempt_timeout_generator from google.cloud.bigtable.data._helpers import _make_metadata +from google.cloud.bigtable.data._helpers import backoff_generator + +from google.api_core.grpc_helpers_async import GrpcAsyncStream from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.api_core import retry as retries -from google.api_core.retry import exponential_sleep_generator if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync + from google.cloud.bigtable.data._metrics import ActiveOperationMetric class _ResetRow(Exception): @@ -70,6 +73,7 @@ class _ReadRowsOperationAsync: "_metadata", "_last_yielded_row_key", "_remaining_count", + "_operation_metrics", ) def __init__( @@ -78,6 +82,7 @@ def __init__( table: "TableAsync", 
operation_timeout: float, attempt_timeout: float, + metrics: ActiveOperationMetric, retryable_exceptions: Sequence[type[Exception]] = (), ): self.attempt_timeout_gen = _attempt_timeout_generator( @@ -100,17 +105,26 @@ def __init__( ) self._last_yielded_row_key: bytes | None = None self._remaining_count: int | None = self.request.rows_limit or None + self._operation_metrics = metrics def start_operation(self) -> AsyncGenerator[Row, None]: """ Start the read_rows operation, retrying on retryable errors. """ + sleep_generator = backoff_generator() + self._operation_metrics.backoff_generator = sleep_generator + + # Metrics: + # track attempt failures using build_wrapped_fn_handlers() for raised exceptions + # and operation timeouts + metric_fns = self._operation_metrics.build_wrapped_fn_handlers(self._predicate) + metric_predicate, metric_exc_factory = metric_fns return retries.retry_target_stream_async( self._read_rows_attempt, - self._predicate, - exponential_sleep_generator(0.01, 60, multiplier=2), + metric_fns[0], + sleep_generator, self.operation_timeout, - exception_factory=_retry_exception_factory, + exception_factory=metric_fns[1], ) def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: @@ -120,6 +134,8 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: which will call this function until it succeeds or a non-retryable error is raised. 
""" + # register metric start + self._operation_metrics.start_attempt() # revise request keys and ranges between attempts if self._last_yielded_row_key is not None: # if this is a retry, try to trim down the request to avoid ones we've already processed @@ -130,12 +146,12 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: ) except _RowSetComplete: # if we've already seen all the rows, we're done - return self.merge_rows(None) + return self.merge_rows(None, self._operation_metrics) # revise the limit based on number of rows already yielded if self._remaining_count is not None: self.request.rows_limit = self._remaining_count if self._remaining_count == 0: - return self.merge_rows(None) + return self.merge_rows(None, self._operation_metrics) # create and return a new row merger gapic_stream = self.table.client._gapic_client.read_rows( self.request, @@ -144,70 +160,82 @@ def _read_rows_attempt(self) -> AsyncGenerator[Row, None]: retry=None, ) chunked_stream = self.chunk_stream(gapic_stream) - return self.merge_rows(chunked_stream) + return self.merge_rows(chunked_stream, self._operation_metrics) async def chunk_stream( - self, stream: Awaitable[AsyncIterable[ReadRowsResponsePB]] + self, stream: Awaitable[GrpcAsyncStream[ReadRowsResponsePB]] ) -> AsyncGenerator[ReadRowsResponsePB.CellChunk, None]: """ process chunks out of raw read_rows stream """ - async for resp in await stream: - # extract proto from proto-plus wrapper - resp = resp._pb + call = await stream + try: + async for resp in call: + # extract proto from proto-plus wrapper + resp = resp._pb + + # handle last_scanned_row_key packets, sent when server + # has scanned past the end of the row range + if resp.last_scanned_row_key: + if ( + self._last_yielded_row_key is not None + and resp.last_scanned_row_key <= self._last_yielded_row_key + ): + raise InvalidChunk("last scanned out of order") + self._last_yielded_row_key = resp.last_scanned_row_key - # handle last_scanned_row_key packets, sent when 
server - # has scanned past the end of the row range - if resp.last_scanned_row_key: - if ( - self._last_yielded_row_key is not None - and resp.last_scanned_row_key <= self._last_yielded_row_key - ): - raise InvalidChunk("last scanned out of order") - self._last_yielded_row_key = resp.last_scanned_row_key - - current_key = None - # process each chunk in the response - for c in resp.chunks: - if current_key is None: - current_key = c.row_key + current_key = None + # process each chunk in the response + for c in resp.chunks: if current_key is None: - raise InvalidChunk("first chunk is missing a row key") - elif ( - self._last_yielded_row_key - and current_key <= self._last_yielded_row_key - ): - raise InvalidChunk("row keys should be strictly increasing") + current_key = c.row_key + if current_key is None: + raise InvalidChunk("first chunk is missing a row key") + elif ( + self._last_yielded_row_key + and current_key <= self._last_yielded_row_key + ): + raise InvalidChunk("row keys should be strictly increasing") - yield c + yield c - if c.reset_row: - current_key = None - elif c.commit_row: - # update row state after each commit - self._last_yielded_row_key = current_key - if self._remaining_count is not None: - self._remaining_count -= 1 - if self._remaining_count < 0: - raise InvalidChunk("emit count exceeds row limit") - current_key = None + if c.reset_row: + current_key = None + elif c.commit_row: + # update row state after each commit + self._last_yielded_row_key = current_key + if self._remaining_count is not None: + self._remaining_count -= 1 + if self._remaining_count < 0: + raise InvalidChunk("emit count exceeds row limit") + current_key = None + finally: + # ensure stream is closed + call.cancel() + # send trailing metadata to metrics + metadata = await call.trailing_metadata() + await call.initial_metadata() + self._operation_metrics.add_response_metadata(metadata) @staticmethod async def merge_rows( - chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, 
None] | None + chunks: AsyncGenerator[ReadRowsResponsePB.CellChunk, None] | None, + operation: ActiveOperationMetric, ): """ Merge chunks into rows """ if chunks is None: + operation.end_with_success() return it = chunks.__aiter__() + is_first_row = True # For each row while True: try: c = await it.__anext__() except StopAsyncIteration: # stream complete + operation.end_with_success() return row_key = c.row_key @@ -284,7 +312,17 @@ async def merge_rows( Cell(value, row_key, family, qualifier, ts, list(labels)) ) if c.commit_row: + if is_first_row: + # record first row latency in metrics + is_first_row = False + operation.attempt_first_response() + block_time = time.monotonic() yield Row(row_key, cells) + # most metric operations use setters, but this one updates + # the value directly to avoid extra overhead + operation.active_attempt.application_blocking_time_ms += ( # type: ignore + time.monotonic() - block_time + ) * 1000 break c = await it.__anext__() except _ResetRow as e: diff --git a/google/cloud/bigtable/data/_async/client.py b/google/cloud/bigtable/data/_async/client.py index 6fbc93d50..bcc57c308 100644 --- a/google/cloud/bigtable/data/_async/client.py +++ b/google/cloud/bigtable/data/_async/client.py @@ -71,6 +71,7 @@ from google.cloud.bigtable.data._helpers import _get_retryable_errors from google.cloud.bigtable.data._helpers import _get_timeouts from google.cloud.bigtable.data._helpers import _attempt_timeout_generator +from google.cloud.bigtable.data._helpers import backoff_generator from google.cloud.bigtable.data._async.mutations_batcher import MutationsBatcherAsync from google.cloud.bigtable.data._async.mutations_batcher import _MB_SIZE from google.cloud.bigtable.data.read_modify_write_rules import ReadModifyWriteRule @@ -80,6 +81,7 @@ from google.cloud.bigtable.data.row_filters import RowFilterChain from google.cloud.bigtable.data._metrics import BigtableClientSideMetricsController +from google.cloud.bigtable.data._metrics import OperationType if 
TYPE_CHECKING: @@ -540,7 +542,6 @@ def __init__( table_id=table_id, app_profile_id=app_profile_id, ) - self.default_read_rows_retryable_errors = ( default_read_rows_retryable_errors or () ) @@ -567,6 +568,7 @@ async def read_rows_stream( attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + **kwargs, ) -> AsyncIterable[Row]: """ Read a set of rows from the table, based on the specified query. @@ -594,7 +596,7 @@ async def read_rows_stream( - an asynchronous iterator that yields rows returned by the query Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - GoogleAPIError: raised if the request encounters an unrecoverable error """ @@ -603,11 +605,20 @@ async def read_rows_stream( ) retryable_excs = _get_retryable_errors(retryable_errors, self) + # extract metric operation if passed down through kwargs + # used so that read_row can disable is_streaming flag + metric_operation = kwargs.pop("metric_operation", None) + if metric_operation is None: + metric_operation = self._metrics.create_operation( + OperationType.READ_ROWS, is_streaming=True + ) + row_merger = _ReadRowsOperationAsync( query, self, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + metrics=metric_operation, retryable_exceptions=retryable_excs, ) return row_merger.start_operation() @@ -620,6 +631,7 @@ async def read_rows( attempt_timeout: float | None | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, retryable_errors: Sequence[type[Exception]] | TABLE_DEFAULT = TABLE_DEFAULT.READ_ROWS, + **kwargs, ) -> list[Row]: """ Read a set of rows from the table, based on the specified query. 
@@ -650,15 +662,16 @@ async def read_rows( - a list of Rows returned by the query Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ row_generator = await self.read_rows_stream( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, retryable_errors=retryable_errors, + **kwargs, ) return [row async for row in row_generator] @@ -697,17 +710,21 @@ async def read_row( - a Row object if the row exists, otherwise None Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ if row_key is None: raise ValueError("row_key must be string or bytes") + metric_operation = self._metrics.create_operation( + OperationType.READ_ROWS, is_streaming=False + ) query = ReadRowsQuery(row_keys=row_key, row_filter=row_filter, limit=1) results = await self.read_rows( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + metric_operation=metric_operation, retryable_errors=retryable_errors, ) if len(results) == 0: @@ -770,8 +787,8 @@ async def read_rows_sharded( for i in range(0, len(sharded_query), _CONCURRENCY_LIMIT) ] # run batches and collect results - results_list = [] - error_dict = {} + results_list: list[Row] = [] + error_dict: dict[int, Exception] = {} shard_idx = 0 for batch in batched_queries: batch_operation_timeout = 
next(timeout_generator) @@ -838,9 +855,9 @@ async def row_exists( - a bool indicating whether the row exists Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ if row_key is None: raise ValueError("row_key must be string or bytes") @@ -849,10 +866,14 @@ async def row_exists( limit_filter = CellsRowLimitFilter(1) chain_filter = RowFilterChain(filters=[limit_filter, strip_filter]) query = ReadRowsQuery(row_keys=row_key, limit=1, row_filter=chain_filter) + metric_operation = self._metrics.create_operation( + OperationType.READ_ROWS, is_streaming=False + ) results = await self.read_rows( query, operation_timeout=operation_timeout, attempt_timeout=attempt_timeout, + metric_operation=metric_operation, retryable_errors=retryable_errors, ) return len(results) > 0 @@ -894,9 +915,9 @@ async def sample_row_keys( - a set of RowKeySamples the delimit contiguous sections of the table Raises: - DeadlineExceeded: raised after operation timeout - will be chained with a RetryExceptionGroup containing GoogleAPIError exceptions + will be chained with a RetryExceptionGroup containing GoogleAPICallError exceptions from any retries that failed - - GoogleAPIError: raised if the request encounters an unrecoverable error + - GoogleAPICallError: raised if the request encounters an unrecoverable error """ # prepare timeouts operation_timeout, attempt_timeout = _get_timeouts( @@ -909,28 +930,43 @@ async def sample_row_keys( retryable_excs = _get_retryable_errors(retryable_errors, self) predicate = retries.if_exception_type(*retryable_excs) - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) + sleep_generator = 
backoff_generator() # prepare request metadata = _make_metadata(self.table_name, self.app_profile_id) - async def execute_rpc(): - results = await self.client._gapic_client.sample_row_keys( - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=next(attempt_timeout_gen), - metadata=metadata, - retry=None, - ) - return [(s.row_key, s.offset_bytes) async for s in results] + # wrap rpc in retry and metric collection logic + async with self._metrics.create_operation( + OperationType.SAMPLE_ROW_KEYS, backoff_generator=sleep_generator + ) as operation: + + async def execute_rpc(): + stream = await self.client._gapic_client.sample_row_keys( + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=next(attempt_timeout_gen), + metadata=metadata, + retry=None, + ) + samples = [(s.row_key, s.offset_bytes) async for s in stream] + # send metadata to metric collector + call_metadata = ( + await stream.trailing_metadata() + await stream.initial_metadata() + ) + operation.add_response_metadata(call_metadata) + # return results + return samples - return await retries.retry_target_async( - execute_rpc, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + metric_wrapped = operation.wrap_attempt_fn( + execute_rpc, extract_call_metadata=False + ) + return await retries.retry_target_async( + metric_wrapped, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) def mutations_batcher( self, @@ -1023,8 +1059,8 @@ async def mutate_row( Raises: - DeadlineExceeded: raised after operation timeout will be chained with a RetryExceptionGroup containing all - GoogleAPIError exceptions from any retries that failed - - GoogleAPIError: raised on non-idempotent operations that cannot be + GoogleAPICallError exceptions from any retries that failed + - GoogleAPICallError: raised on non-idempotent operations that cannot be safely retried. 
- ValueError if invalid arguments are provided """ @@ -1045,25 +1081,34 @@ async def mutate_row( # mutations should not be retried predicate = retries.if_exception_type() - sleep_generator = retries.exponential_sleep_generator(0.01, 2, 60) - - target = partial( - self.client._gapic_client.mutate_row, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - mutations=[mutation._to_pb() for mutation in mutations_list], - table_name=self.table_name, - app_profile_id=self.app_profile_id, - timeout=attempt_timeout, - metadata=_make_metadata(self.table_name, self.app_profile_id), - retry=None, - ) - return await retries.retry_target_async( - target, - predicate, - sleep_generator, - operation_timeout, - exception_factory=_retry_exception_factory, - ) + sleep_generator = backoff_generator() + + # wrap rpc in retry and metric collection logic + async with self._metrics.create_operation( + OperationType.MUTATE_ROW, backoff_generator=sleep_generator + ) as operation: + metric_wrapped = operation.wrap_attempt_fn( + self.client._gapic_client.mutate_row + ) + target = partial( + metric_wrapped, + row_key=row_key.encode("utf-8") + if isinstance(row_key, str) + else row_key, + mutations=[mutation._to_pb() for mutation in mutations_list], + table_name=self.table_name, + app_profile_id=self.app_profile_id, + timeout=attempt_timeout, + metadata=_make_metadata(self.table_name, self.app_profile_id), + retry=None, + ) + return await retries.retry_target_async( + target, + predicate, + sleep_generator, + operation_timeout, + exception_factory=_retry_exception_factory, + ) async def bulk_mutate_rows( self, @@ -1119,6 +1164,7 @@ async def bulk_mutate_rows( mutation_entries, operation_timeout, attempt_timeout, + self._metrics.create_operation(OperationType.BULK_MUTATE_ROWS), retryable_exceptions=retryable_excs, ) await operation.start() @@ -1165,7 +1211,7 @@ async def check_and_mutate_row( Returns: - bool indicating whether the predicate was true or false Raises: - - 
GoogleAPIError exceptions from grpc call + - GoogleAPICallError exceptions from grpc call """ operation_timeout, _ = _get_timeouts(operation_timeout, None, self) if true_case_mutations is not None and not isinstance( @@ -1179,18 +1225,25 @@ async def check_and_mutate_row( false_case_mutations = [false_case_mutations] false_case_list = [m._to_pb() for m in false_case_mutations or []] metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.check_and_mutate_row( - true_mutations=true_case_list, - false_mutations=false_case_list, - predicate_filter=predicate._to_pb() if predicate is not None else None, - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - return result.predicate_matched + + async with self._metrics.create_operation( + OperationType.CHECK_AND_MUTATE + ) as operation: + metric_wrapped = operation.wrap_attempt_fn( + self.client._gapic_client.check_and_mutate_row + ) + result = await metric_wrapped( + true_mutations=true_case_list, + false_mutations=false_case_list, + predicate_filter=predicate._to_pb() if predicate is not None else None, + row_key=row_key.encode() if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + return result.predicate_matched async def read_modify_write_row( self, @@ -1223,7 +1276,7 @@ async def read_modify_write_row( - Row: containing cell data that was modified as part of the operation Raises: - - GoogleAPIError exceptions from grpc call + - GoogleAPICallError exceptions from grpc call - ValueError if invalid arguments are provided """ operation_timeout, _ = _get_timeouts(operation_timeout, None, self) @@ -1234,17 +1287,25 @@ async def read_modify_write_row( if not rules: raise ValueError("rules must 
contain at least one item") metadata = _make_metadata(self.table_name, self.app_profile_id) - result = await self.client._gapic_client.read_modify_write_row( - rules=[rule._to_pb() for rule in rules], - row_key=row_key.encode("utf-8") if isinstance(row_key, str) else row_key, - table_name=self.table_name, - app_profile_id=self.app_profile_id, - metadata=metadata, - timeout=operation_timeout, - retry=None, - ) - # construct Row from result - return Row._from_pb(result.row) + + async with self._metrics.create_operation( + OperationType.READ_MODIFY_WRITE + ) as operation: + metric_wrapped = operation.wrap_attempt_fn( + self.client._gapic_client.read_modify_write_row + ) + + result = await metric_wrapped( + rules=[rule._to_pb() for rule in rules], + row_key=row_key.encode() if isinstance(row_key, str) else row_key, + table_name=self.table_name, + app_profile_id=self.app_profile_id, + metadata=metadata, + timeout=operation_timeout, + retry=None, + ) + # construct Row from result + return Row._from_pb(result.row) async def close(self): """ diff --git a/google/cloud/bigtable/data/_async/mutations_batcher.py b/google/cloud/bigtable/data/_async/mutations_batcher.py index 5d5dd535e..dbf3102a9 100644 --- a/google/cloud/bigtable/data/_async/mutations_batcher.py +++ b/google/cloud/bigtable/data/_async/mutations_batcher.py @@ -17,6 +17,7 @@ from typing import Any, Sequence, TYPE_CHECKING import asyncio import atexit +import time import warnings from collections import deque @@ -32,6 +33,8 @@ _MUTATE_ROWS_REQUEST_MUTATION_LIMIT, ) from google.cloud.bigtable.data.mutations import Mutation +from google.cloud.bigtable.data._metrics import OperationType +from google.cloud.bigtable.data._metrics import ActiveOperationMetric if TYPE_CHECKING: from google.cloud.bigtable.data._async.client import TableAsync @@ -328,9 +331,18 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): """ # flush new entries in_process_requests: 
list[asyncio.Future[list[FailedMutationEntryError]]] = [] + metric = self._table._metrics.create_operation(OperationType.BULK_MUTATE_ROWS) + flow_start_time = time.monotonic() async for batch in self._flow_control.add_to_flow(new_entries): - batch_task = self._create_bg_task(self._execute_mutate_rows, batch) + # add time waiting on flow control to throttling metric + metric.flow_throttling_time = time.monotonic() - flow_start_time + batch_task = self._create_bg_task(self._execute_mutate_rows, batch, metric) in_process_requests.append(batch_task) + # start a new metric for next batch + metric = self._table._metrics.create_operation( + OperationType.BULK_MUTATE_ROWS + ) + flow_start_time = time.monotonic() # wait for all inflight requests to complete found_exceptions = await self._wait_for_batch_results(*in_process_requests) # update exception data to reflect any new errors @@ -338,7 +350,7 @@ async def _flush_internal(self, new_entries: list[RowMutationEntry]): self._add_exceptions(found_exceptions) async def _execute_mutate_rows( - self, batch: list[RowMutationEntry] + self, batch: list[RowMutationEntry], metrics: ActiveOperationMetric ) -> list[FailedMutationEntryError]: """ Helper to execute mutation operation on a batch @@ -358,6 +370,7 @@ async def _execute_mutate_rows( batch, operation_timeout=self._operation_timeout, attempt_timeout=self._attempt_timeout, + metrics=metrics, retryable_exceptions=self._retryable_errors, ) await operation.start() @@ -487,10 +500,10 @@ async def _wait_for_batch_results( found_errors = [] for result in all_results: if isinstance(result, Exception): - # will receive direct Exception objects if request task fails + # will receive Exception objects if request task fails. Add to list found_errors.append(result) elif isinstance(result, BaseException): - # BaseException not expected from grpc calls. Raise immediately + # BaseException won't be encountered in normal execution. 
Raise immediately raise result elif result: # completed requests will return a list of FailedMutationEntryError diff --git a/google/cloud/bigtable/data/_helpers.py b/google/cloud/bigtable/data/_helpers.py index 13e921c20..0060501ba 100644 --- a/google/cloud/bigtable/data/_helpers.py +++ b/google/cloud/bigtable/data/_helpers.py @@ -137,7 +137,7 @@ def _retry_exception_factory( timeout_val_str = f"of {timeout_val:0.1f}s " if timeout_val is not None else "" # if failed due to timeout, raise deadline exceeded as primary exception source_exc: Exception = core_exceptions.DeadlineExceeded( - f"operation_timeout{timeout_val_str} exceeded" + f"operation_timeout {timeout_val_str}exceeded" ) elif exc_list: # otherwise, raise non-retryable error as primary exception diff --git a/google/cloud/bigtable/data/_metrics/data_model.py b/google/cloud/bigtable/data/_metrics/data_model.py index fa7efdd37..87172b8c6 100644 --- a/google/cloud/bigtable/data/_metrics/data_model.py +++ b/google/cloud/bigtable/data/_metrics/data_model.py @@ -27,9 +27,13 @@ from dataclasses import field from grpc import StatusCode -import google.cloud.bigtable.data.exceptions as bt_exceptions +from google.cloud.bigtable.data.exceptions import FailedQueryShardError +from google.cloud.bigtable.data.exceptions import FailedMutationEntryError +from google.cloud.bigtable.data.exceptions import _BigtableExceptionGroup +from google.cloud.bigtable.data._helpers import _retry_exception_factory from google.cloud.bigtable_v2.types.response_params import ResponseParams from google.protobuf.message import DecodeError +from google.api_core.retry import RetryFailureReason if TYPE_CHECKING: from google.cloud.bigtable.data._metrics.handlers._base import MetricsHandler @@ -364,13 +368,17 @@ def end_with_success(self): """ return self.end_with_status(StatusCode.OK) - def build_wrapped_predicate( - self, inner_predicate: Callable[[Exception], bool] + def build_wrapped_fn_handlers( + self, + inner_predicate: Callable[[Exception], 
bool], ) -> Callable[[Exception], bool]: """ - Wrapps a predicate to include metrics tracking. Any call to the resulting predicate - is assumed to be an rpc failure, and will either mark the end of the active attempt - or the end of the operation. + One way to track metrics is by wrapping the `predicate` and `exception_factory` + arguments of `api_core.Retry`. This will notify us when an exception occurs so + we can track it. + + This function returns wrapped versions of the `predicate` and `exception_factory` + to be passed down when building the `Retry` object. Args: - predicate: The predicate to wrap. @@ -384,7 +392,17 @@ def wrapped_predicate(exc: Exception) -> bool: self.end_with_status(exc) return inner_result - return wrapped_predicate + def wrapped_exception_factory( + exc_list: list[Exception], + reason: RetryFailureReason, + timeout_val: float | None, + ) -> tuple[Exception, Exception | None]: + exc, source = _retry_exception_factory(exc_list, reason, timeout_val) + if reason != RetryFailureReason.NON_RETRYABLE_ERROR: + self.end_with_status(exc) + return exc, source + + return wrapped_predicate, wrapped_exception_factory @staticmethod def _exc_to_status(exc: Exception) -> StatusCode: @@ -399,8 +417,14 @@ def _exc_to_status(exc: Exception) -> StatusCode: Args: - exc: The exception to extract the status code from.
""" - if isinstance(exc, bt_exceptions._BigtableExceptionGroup): - exc = exc.exceptions[-1] + # parse bigtable custom exceptions + if isinstance(exc, _BigtableExceptionGroup) and exc.exceptions: + # find most recent in group + return ActiveOperationMetric._exc_to_status(exc.exceptions[-1]) + if isinstance(exc, (FailedMutationEntryError, FailedQueryShardError)): + # find cause of failed entries + return ActiveOperationMetric._exc_to_status(exc.__cause__) + # parse grpc exceptions if hasattr(exc, "grpc_status_code") and exc.grpc_status_code is not None: return exc.grpc_status_code if ( diff --git a/google/cloud/bigtable/data/_metrics/handlers/gcp_exporter.py b/google/cloud/bigtable/data/_metrics/handlers/gcp_exporter.py index 7f0e0365c..af513e92b 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/gcp_exporter.py +++ b/google/cloud/bigtable/data/_metrics/handlers/gcp_exporter.py @@ -53,9 +53,9 @@ # fmt: off MILLIS_AGGREGATION = view.ExplicitBucketHistogramAggregation( [ - 0, 0.01, 0.05, 0.1, 0.3, 0.6, 0.8, 1, 2, 3, 4, 5, 6, 8, 10, 13, 16, - 20, 25, 30, 40, 50, 65, 80, 100, 130, 160, 200, 250, 300, 400, - 500, 650, 800, 1000, 2000, 5000, 10000, 20000, 50000, 100000, + 0, 1, 2, 3, 4, 5, 6, 8, 10, 13, 16, 20, 25, 30, 40, 50, 65, 80, 100, + 130, 160, 200, 250, 300, 400, 500, 650, 800, 1000, 2000, 5000, 10_000, + 20_000, 50_000, 100_000, 200_000, 400_000, 800_000, 1_600_000, 3_200_000 ] ) # fmt: on @@ -81,6 +81,9 @@ for n in INSTRUMENT_NAMES ] +# We use the minimum value supported by the Cloud Monitoring API (60 seconds) +EXPORT_INTERVAL_MS = 60_000 + class GoogleCloudMetricsHandler(OpenTelemetryMetricsHandler): """ @@ -102,12 +105,12 @@ class GoogleCloudMetricsHandler(OpenTelemetryMetricsHandler): - export_interval: The interval (in seconds) at which to export metrics to Cloud Monitoring. 
""" - def __init__(self, *args, project_id: str, export_interval=60, **kwargs): + def __init__(self, *args, project_id: str, **kwargs): # internal exporter to write metrics to Cloud Monitoring exporter = _BigtableMetricsExporter(project_id=project_id) # periodically executes exporter gcp_reader = PeriodicExportingMetricReader( - exporter, export_interval_millis=export_interval * 1000 + exporter, export_interval_millis=EXPORT_INTERVAL_MS ) # use private meter provider to store instruments and views meter_provider = MeterProvider(metric_readers=[gcp_reader], views=VIEW_LIST) diff --git a/google/cloud/bigtable/data/_metrics/handlers/opentelemetry.py b/google/cloud/bigtable/data/_metrics/handlers/opentelemetry.py index 9f4f2caf3..5dd6949d5 100644 --- a/google/cloud/bigtable/data/_metrics/handlers/opentelemetry.py +++ b/google/cloud/bigtable/data/_metrics/handlers/opentelemetry.py @@ -178,8 +178,9 @@ def on_operation_complete(self, op: CompletedOperationMetric) -> None: op.duration_ms, {"streaming": is_streaming, **labels} ) # only record completed attempts if there were retries - if op.completed_attempts: - self.otel.retry_count.add(len(op.completed_attempts) - 1, labels) + num_attempts = len(op.completed_attempts) + if num_attempts > 1: + self.otel.retry_count.add(num_attempts - 1, labels) def on_attempt_complete( self, attempt: CompletedAttemptMetric, op: ActiveOperationMetric @@ -211,7 +212,8 @@ def on_attempt_complete( combined_throttling += op.flow_throttling_time_ms self.otel.throttling_latencies.record(combined_throttling, labels) self.otel.application_latencies.record( - attempt.application_blocking_time_ms + attempt.backoff_before_attempt_ms, labels + attempt.application_blocking_time_ms + attempt.backoff_before_attempt_ms, + labels, ) if ( op.op_type == OperationType.READ_ROWS @@ -228,6 +230,4 @@ def on_attempt_complete( else: # gfe headers not attached. Record a connectivity error. 
# TODO: this should not be recorded as an error when direct path is enabled - self.otel.connectivity_error_count.add( - 1, {"status": status, **labels} - ) + self.otel.connectivity_error_count.add(1, {"status": status, **labels}) diff --git a/google/cloud/bigtable_v2/services/bigtable/async_client.py b/google/cloud/bigtable_v2/services/bigtable/async_client.py index 0421e19bc..48d19e327 100644 --- a/google/cloud/bigtable_v2/services/bigtable/async_client.py +++ b/google/cloud/bigtable_v2/services/bigtable/async_client.py @@ -36,6 +36,7 @@ from google.api_core.client_options import ClientOptions from google.api_core import exceptions as core_exceptions from google.api_core import gapic_v1 +from google.api_core.grpc_helpers_async import GrpcAsyncStream from google.api_core import retry_async as retries from google.auth import credentials as ga_credentials # type: ignore from google.oauth2 import service_account # type: ignore @@ -260,7 +261,7 @@ def read_rows( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.ReadRowsResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.ReadRowsResponse]]: r"""Streams back the contents of all requested rows in key order, optionally applying the same Reader filter to each. Depending on their size, rows and cells may be @@ -357,7 +358,7 @@ def sample_row_keys( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.SampleRowKeysResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.SampleRowKeysResponse]]: r"""Returns a sample of row keys in the table. The returned row keys will delimit contiguous sections of the table of approximately equal size, which can be used @@ -444,7 +445,7 @@ def sample_row_keys( # Done; return the response. 
return response - async def mutate_row( + def mutate_row( self, request: Optional[Union[bigtable.MutateRowRequest, dict]] = None, *, @@ -551,17 +552,14 @@ async def mutate_row( # Validate the universe domain. self._client._validate_universe_domain() - # Send the request. - response = await rpc( + # Return the grpc call coroutine + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - def mutate_rows( self, request: Optional[Union[bigtable.MutateRowsRequest, dict]] = None, @@ -572,7 +570,7 @@ def mutate_rows( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.MutateRowsResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.MutateRowsResponse]]: r"""Mutates multiple rows in a batch. Each individual row is mutated atomically as in MutateRow, but the entire batch is not executed atomically. @@ -674,7 +672,7 @@ def mutate_rows( # Done; return the response. return response - async def check_and_mutate_row( + def check_and_mutate_row( self, request: Optional[Union[bigtable.CheckAndMutateRowRequest, dict]] = None, *, @@ -819,18 +817,15 @@ async def check_and_mutate_row( # Validate the universe domain. self._client._validate_universe_domain() - # Send the request. - response = await rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - - async def ping_and_warm( + def ping_and_warm( self, request: Optional[Union[bigtable.PingAndWarmRequest, dict]] = None, *, @@ -913,18 +908,15 @@ async def ping_and_warm( # Validate the universe domain. self._client._validate_universe_domain() - # Send the request. - response = await rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. 
- return response - - async def read_modify_write_row( + def read_modify_write_row( self, request: Optional[Union[bigtable.ReadModifyWriteRowRequest, dict]] = None, *, @@ -1038,17 +1030,14 @@ async def read_modify_write_row( # Validate the universe domain. self._client._validate_universe_domain() - # Send the request. - response = await rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - def generate_initial_change_stream_partitions( self, request: Optional[ @@ -1061,7 +1050,7 @@ def generate_initial_change_stream_partitions( timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), ) -> Awaitable[ - AsyncIterable[bigtable.GenerateInitialChangeStreamPartitionsResponse] + GrpcAsyncStream[bigtable.GenerateInitialChangeStreamPartitionsResponse] ]: r"""NOTE: This API is intended to be used by Apache Beam BigtableIO. Returns the current list of partitions that make up the table's @@ -1148,17 +1137,14 @@ def generate_initial_change_stream_partitions( # Validate the universe domain. self._client._validate_universe_domain() - # Send the request. - response = rpc( + # Return the grpc call coroutine. + return rpc( request, retry=retry, timeout=timeout, metadata=metadata, ) - # Done; return the response. - return response - def read_change_stream( self, request: Optional[Union[bigtable.ReadChangeStreamRequest, dict]] = None, @@ -1168,7 +1154,7 @@ def read_change_stream( retry: OptionalRetry = gapic_v1.method.DEFAULT, timeout: Union[float, object] = gapic_v1.method.DEFAULT, metadata: Sequence[Tuple[str, str]] = (), - ) -> Awaitable[AsyncIterable[bigtable.ReadChangeStreamResponse]]: + ) -> Awaitable[GrpcAsyncStream[bigtable.ReadChangeStreamResponse]]: r"""NOTE: This API is intended to be used by Apache Beam BigtableIO. Reads changes from a table's change stream. 
Changes will reflect both user-initiated mutations and diff --git a/mypy.ini b/mypy.ini index 31cc24223..1c755808e 100644 --- a/mypy.ini +++ b/mypy.ini @@ -26,3 +26,6 @@ ignore_missing_imports = True [mypy-pytest] ignore_missing_imports = True + +[mypy-google.api.*] +ignore_missing_imports = True diff --git a/tests/system/data/setup_fixtures.py b/tests/system/data/setup_fixtures.py index 77086b7f3..c7dda51e7 100644 --- a/tests/system/data/setup_fixtures.py +++ b/tests/system/data/setup_fixtures.py @@ -77,10 +77,12 @@ def instance_id(admin_client, project_id, cluster_config): operation.result(timeout=240) except exceptions.AlreadyExists: pass - yield instance_id - admin_client.instance_admin_client.delete_instance( - name=f"projects/{project_id}/instances/{instance_id}" - ) + try: + yield instance_id + finally: + admin_client.instance_admin_client.delete_instance( + name=f"projects/{project_id}/instances/{instance_id}" + ) @pytest.fixture(scope="session") diff --git a/tests/system/data/test_metrics.py b/tests/system/data/test_metrics.py new file mode 100644 index 000000000..56228361f --- /dev/null +++ b/tests/system/data/test_metrics.py @@ -0,0 +1,640 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+ +import pytest +import pytest_asyncio +import mock +from functools import partial + +from google.api_core.exceptions import NotFound, ServiceUnavailable, DeadlineExceeded +from google.cloud.bigtable.data._metrics.data_model import OperationType +from google.cloud.bigtable.data import MutationsExceptionGroup +from google.api_core import retry + +from google.cloud.bigtable.data import SetCell, ReadRowsQuery, RowMutationEntry +from google.cloud.bigtable.data.read_modify_write_rules import IncrementRule +from google.cloud.bigtable.data.row_filters import PassAllFilter +from google.cloud.bigtable_v2 import BigtableAsyncClient + +from .test_system import ( + TEST_FAMILY, + TEST_CLUSTER, + TEST_ZONE, + cluster_config, # noqa: F401 + column_family_config, # noqa: F401 + init_table_id, # noqa: F401 +) + +# use this value to make sure we are testing against a consistent set of metrics +# if _populate_calls is modified, this value will have to change +EXPECTED_METRIC_COUNT = 172 + +ALL_INSTRUMENTS = [ + "operation_latencies", + "attempt_latencies", + "server_latencies", + "first_response_latencies", + "application_blocking_latencies", + "client_blocking_latencies", + "retry_count", + "connectivity_error_count", +] + + +async def _populate_calls(client, instance_id, table_id): + """ + Call rpcs to populate the backend with a set of metrics to run tests against: + - successful rpcs + - rpcs that raise terminal NotFound error + - rpcs that raise retryable ServiceUnavailable error and then succeed + - rpcs that raise retryable ServiceUnavailable error and then timeout + - rpcs that raise retryable ServiceUnavailable error and then raise terminal NotFound error + + Each set of calls is made with a unique app profile to make assertions about each easier + This app profile is not set on the backend, but only passed to the metric handler + """ + + # helper function to build separate table instances for each app profile + def table_with_profile(app_profile_id): + from
google.cloud.bigtable.data._metrics import GoogleCloudMetricsHandler + + table = client.get_table(instance_id, table_id) + kwargs = { + "project_id": client.project, + "instance_id": instance_id, + "table_id": table_id, + "app_profile_id": app_profile_id, # use artificial app profile for metrics, not in backend + } + # create a GCP exporter with 1 second interval + with mock.patch("google.cloud.bigtable.data._metrics.handlers.gcp_exporter.EXPORT_INTERVAL_MS", 1000): + table._metrics.handlers = [ + GoogleCloudMetricsHandler(**kwargs), + ] + return table + + # helper function that builds callables to execute each type of rpc + def _get_stubs_for_table(table): + family = TEST_FAMILY + qualifier = b"qualifier" + init_cell = SetCell(family=family, qualifier=qualifier, new_value=0) + rpc_stubs = { + "mutate_row": [partial(table.mutate_row, b"row1", init_cell)], + "mutate_rows": [ + partial( + table.bulk_mutate_rows, [RowMutationEntry(b"row2", [init_cell])] + ) + ], + "read_rows": [ + partial(table.read_row, b"row1"), + partial(table.read_rows, ReadRowsQuery(row_keys=[b"row1", b"row2"])), + ], + "sample_row_keys": [table.sample_row_keys], + "read_modify_write_row": [ + partial( + table.read_modify_write_row, + b"row1", + IncrementRule(family, qualifier, 1), + ) + ], + "check_and_mutate_row": [ + partial( + table.check_and_mutate_row, + b"row1", + PassAllFilter(True), + true_case_mutations=[init_cell], + ) + ], + } + return rpc_stubs + + # Call each rpc with no errors. Should be successful. + print("populating successful rpcs...") + async with table_with_profile("success") as table: + stubs = _get_stubs_for_table(table) + for stub_list in stubs.values(): + for stub in stub_list: + await stub() + + # Call each rpc with a terminal exception. 
Does not hit gcp servers + print("populating terminal NotFound rpcs...") + async with table_with_profile("terminal_exception") as table: + stubs = _get_stubs_for_table(table) + for rpc_name, stub_list in stubs.items(): + for stub in stub_list: + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{rpc_name}", + side_effect=NotFound("test"), + ): + with pytest.raises((NotFound, MutationsExceptionGroup)): + await stub() + + non_retryable_rpcs = ["read_modify_write_row", "check_and_mutate_row"] + # Calls hit retryable errors, then succeed + print("populating retryable success rpcs...") + async with table_with_profile("retry_then_success") as table: + stubs = { + k: v + for k, v in _get_stubs_for_table(table).items() + if k not in non_retryable_rpcs + } + for rpc_name, stub_list in stubs.items(): + for stub in stub_list: + true_fn = BigtableAsyncClient.__dict__[rpc_name] + counter = 0 + + # raise errors twice, then call true function + def side_effect(*args, **kwargs): + nonlocal counter + nonlocal true_fn + if counter < 2: + counter += 1 + raise ServiceUnavailable("test") + return true_fn(table.client._gapic_client, *args, **kwargs) + + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{rpc_name}", + side_effect=side_effect, + ): + await stub(retryable_errors=(ServiceUnavailable,)) + + # Calls hit retryable errors, then hit deadline + # should have 2 attempts, through mocked backoff + print("populating retryable timeout rpcs...") + async with table_with_profile("retry_then_timeout") as table: + stubs = { + k: v + for k, v in _get_stubs_for_table(table).items() + if k not in non_retryable_rpcs + } + for rpc_name, stub_list in stubs.items(): + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{rpc_name}", + side_effect=ServiceUnavailable("test"), + ): + for stub in stub_list: + with mock.patch( + "google.cloud.bigtable.data._helpers.exponential_sleep_generator", + return_value=iter([0.01, 0.01, 5]), + ): + with 
pytest.raises((DeadlineExceeded, MutationsExceptionGroup)): + await stub( + operation_timeout=0.5, + retryable_errors=(ServiceUnavailable,), + ) + + # Calls hit retryable errors, then hit terminal exception + print("populating retryable then terminal error rpcs...") + async with table_with_profile("retry_then_terminal") as table: + stubs = { + k: v + for k, v in _get_stubs_for_table(table).items() + if k not in non_retryable_rpcs + } + for rpc_name, stub_list in stubs.items(): + for stub in stub_list: + error_list = [ServiceUnavailable("test")] * 2 + [NotFound("test")] + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{rpc_name}", + side_effect=error_list, + ): + with pytest.raises((NotFound, MutationsExceptionGroup)): + await stub(retryable_errors=(ServiceUnavailable,)) + + +@pytest_asyncio.fixture(scope="session") +async def get_all_metrics(client, instance_id, table_id, project_id): + from google.protobuf.timestamp_pb2 import Timestamp + from google.cloud.monitoring_v3.types.common import TimeInterval + + # populate table with metrics + start_time = Timestamp() + start_time.GetCurrentTime() + await _populate_calls(client, instance_id, table_id) + + # read them down and save to a list. 
Retry until ready + @retry.Retry(predicate=retry.if_exception_type(NotFound), maximum=5, timeout=2 * 60) + def _read_metrics(): + from google.cloud.monitoring_v3 import MetricServiceClient + + all_responses = [] + client = MetricServiceClient() + end_time = Timestamp() + end_time.GetCurrentTime() + + for instrument in ALL_INSTRUMENTS: + response = client.list_time_series( + name=f"projects/{project_id}", + filter=f'metric.type="bigtable.googleapis.com/client/{instrument}" AND resource.labels.instance="{instance_id}" AND resource.labels.table="{table_id}"', + interval=TimeInterval(start_time=start_time, end_time=end_time), + ) + response = list(response) + if not response: + print(f"no data for {instrument}") + raise NotFound("No metrics found") + all_responses.extend(response) + if len(all_responses) < EXPECTED_METRIC_COUNT: + print( + f"Only found {len(all_responses)} of {EXPECTED_METRIC_COUNT} metrics. Retrying..." + ) + raise NotFound("Not all metrics found") + elif len(all_responses) > EXPECTED_METRIC_COUNT: + raise ValueError(f"Found more metrics than expected: {len(all_responses)}") + return all_responses + + print("waiting for metrics to be ready...") + metrics = _read_metrics() + print("metrics ready") + return metrics + + +@pytest.mark.asyncio +async def test_resource(get_all_metrics, instance_id, table_id, project_id): + """ + all metrics should have monitored resource populated consistently + """ + for m in get_all_metrics: + resource = m.resource + assert resource.type == "bigtable_table" + assert len(resource.labels) == 5 + assert resource.labels["instance"] == instance_id + assert resource.labels["table"] == table_id + assert resource.labels["project_id"] == project_id + # zone and cluster should use default values for failed attempts + assert resource.labels["zone"] in [TEST_ZONE, "global"] + assert resource.labels["cluster"] in [TEST_CLUSTER, "unspecified"] + + +@pytest.mark.asyncio +async def test_client_name(get_all_metrics): + """ + all metrics 
should have client_name populated consistently + """ + from google.cloud.bigtable import __version__ + + for m in get_all_metrics: + client_name = m.metric.labels["client_name"] + assert client_name == "python-bigtable/" + __version__ + + +@pytest.mark.asyncio +async def test_app_profile(get_all_metrics): + """ + all metrics should have app_profile populated with one of the test values + """ + supported_app_profiles = [ + "success", + "terminal_exception", + "retry_then_success", + "retry_then_timeout", + "retry_then_terminal", + ] + for m in get_all_metrics: + app_profile = m.metric.labels["app_profile"] + assert app_profile in supported_app_profiles + + +@pytest.mark.asyncio +async def test_latency_data_types(get_all_metrics): + """ + all latency metrics should have metric_kind DELTA and value_type DISTRIBUTION + """ + latency_metrics = [m for m in get_all_metrics if "latencies" in m.metric.type] + # ensure we got all metrics + assert len(latency_metrics) > 100 + assert any("operation_latencies" in m.metric.type for m in latency_metrics) + assert any("attempt_latencies" in m.metric.type for m in latency_metrics) + assert any("server_latencies" in m.metric.type for m in latency_metrics) + assert any("first_response_latencies" in m.metric.type for m in latency_metrics) + assert any( + "application_blocking_latencies" in m.metric.type for m in latency_metrics + ) + assert any("client_blocking_latencies" in m.metric.type for m in latency_metrics) + # ensure all types are correct + for m in latency_metrics: + assert m.metric_kind == 2 # DELTA + assert m.value_type == 5 # DISTRIBUTION + + +@pytest.mark.asyncio +async def test_count_data_types(get_all_metrics): + """ + all count metrics should have metric_kind DELTA and value_type INT64 + """ + count_metrics = [m for m in get_all_metrics if "count" in m.metric.type] + # ensure we got all metrics + assert len(count_metrics) > 25 + assert any("retry_count" in m.metric.type for m in count_metrics) + assert 
any("connectivity_error_count" in m.metric.type for m in count_metrics) + # ensure all types are correct + for m in count_metrics: + assert m.metric_kind == 2 # DELTA + assert m.value_type == 2 # INT64 + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "instrument,methods", + [ + ("operation_latencies", list(OperationType)), # all operation types + ("attempt_latencies", list(OperationType)), + ("server_latencies", list(OperationType)), + ("application_blocking_latencies", list(OperationType)), + ("client_blocking_latencies", list(OperationType)), + ( + "first_response_latencies", + [OperationType.READ_ROWS], + ), # only valid for ReadRows + ("connectivity_error_count", list(OperationType)), + ( + "retry_count", + [ + OperationType.READ_ROWS, + OperationType.SAMPLE_ROW_KEYS, + OperationType.BULK_MUTATE_ROWS, + OperationType.MUTATE_ROW, + ], + ), # only valid for retryable operations + ], +) +async def test_full_method_coverage(get_all_metrics, instrument, methods): + """ + ensure that each instrument type has data for all expected rpc methods + """ + filtered_metrics = [m for m in get_all_metrics if instrument in m.metric.type] + assert len(filtered_metrics) > 0 + # ensure all methods are covered + for method in methods: + assert any( + method.value in m.metric.labels["method"] for m in filtered_metrics + ), f"{method} not found in {instrument}" + # ensure no unexpected methods are covered + for m in filtered_metrics: + assert m.metric.labels["method"] in [ + m.value for m in methods + ], f"unexpected method {m.metric.labels['method']}" + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "instrument,include_status,include_streaming", + [ + ("operation_latencies", True, True), + ("attempt_latencies", True, True), + ("server_latencies", True, True), + ("first_response_latencies", True, False), + ("connectivity_error_count", True, False), + ("retry_count", True, False), + ("application_blocking_latencies", False, False), + ("client_blocking_latencies", False, 
False), + ], +) +async def test_labels(get_all_metrics, instrument, include_status, include_streaming): + """ + all metrics have 3 common labels: method, client_name, app_profile + + some metrics also have status and streaming labels + """ + assert len(get_all_metrics) > 0 + filtered_metrics = [m for m in get_all_metrics if instrument in m.metric.type] + expected_num = 3 + int(include_status) + int(include_streaming) + for m in filtered_metrics: + labels = m.metric.labels + # check for count + assert len(labels) == expected_num + # check for common labels + assert "client_name" in labels + assert "method" in labels + assert "app_profile" in labels + # check for optional labels + if include_status: + assert "status" in labels + if include_streaming: + assert "streaming" in labels + + +@pytest.mark.asyncio +async def test_streaming_label(get_all_metrics): + """ + streaming=True indicates point-reads using ReadRows rpc + + We should only see it set on ReadRows, and we should see a mix of True and False in + the dataset + """ + # find set of metrics that support streaming tag + streaming_instruments = [ + "operation_latencies", + "attempt_latencies", + "server_latencies", + ] + streaming_metrics = [ + m + for m in get_all_metrics + if any(i in m.metric.type for i in streaming_instruments) + ] + non_read_rows = [ + m + for m in streaming_metrics + if m.metric.labels["method"] != OperationType.READ_ROWS.value + ] + assert len(non_read_rows) > 50 + # ensure all non-read-rows have streaming=False + assert all(m.metric.labels["streaming"] == "false" for m in non_read_rows) + # ensure read-rows have a mix of True and False, for each instrument + for instrument in streaming_instruments: + filtered_read_rows = [ + m + for m in streaming_metrics + if instrument in m.metric.type + and m.metric.labels["method"] == OperationType.READ_ROWS.value + ] + assert len(filtered_read_rows) > 0 + assert any(m.metric.labels["streaming"] == "true" for m in filtered_read_rows) + assert
any(m.metric.labels["streaming"] == "false" for m in filtered_read_rows) + + +@pytest.mark.asyncio +async def test_status_success(get_all_metrics): + """ + check the subset of successful rpcs + + They should have no retries, no connectivity errors, a status of OK + Should have cluster and zone properly set + """ + success_metrics = [ + m for m in get_all_metrics if m.metric.labels["app_profile"] == "success" + ] + # ensure each expected instrument is present in data + assert any("operation_latencies" in m.metric.type for m in success_metrics) + assert any("attempt_latencies" in m.metric.type for m in success_metrics) + assert any("server_latencies" in m.metric.type for m in success_metrics) + assert any("first_response_latencies" in m.metric.type for m in success_metrics) + assert any( + "application_blocking_latencies" in m.metric.type for m in success_metrics + ) + assert any("client_blocking_latencies" in m.metric.type for m in success_metrics) + for m in success_metrics: + # ensure no retries or connectivity errors recorded + assert "connectivity_error_count" not in m.metric.type + assert "retry_count" not in m.metric.type + # if instrument has status label, should be OK + if "status" in m.metric.labels: + assert m.metric.labels["status"] == "OK" + # check for cluster and zone + assert m.resource.labels["zone"] == TEST_ZONE + assert m.resource.labels["cluster"] == TEST_CLUSTER + + +@pytest.mark.asyncio +async def test_status_exception(get_all_metrics): + """ + check the subset of rpcs with a single terminal exception + + They should have no retries, 1+ connectivity errors, a status of NOT_FOUND + Should have default values for cluster and zone + """ + fail_metrics = [ + m + for m in get_all_metrics + if m.metric.labels["app_profile"] == "terminal_exception" + ] + # ensure each expected instrument is present in data + assert any("operation_latencies" in m.metric.type for m in fail_metrics) + assert any("attempt_latencies" in m.metric.type for m in fail_metrics) 
+ assert any("application_blocking_latencies" in m.metric.type for m in fail_metrics) + assert any("client_blocking_latencies" in m.metric.type for m in fail_metrics) + assert any("connectivity_error_count" in m.metric.type for m in fail_metrics) + # server_latencies, first_response_latencies and retry_count are not expected + assert not any("server_latencies" in m.metric.type for m in fail_metrics) + assert not any("retry_count" in m.metric.type for m in fail_metrics) + assert not any("first_response_latencies" in m.metric.type for m in fail_metrics) + for m in fail_metrics: + # if instrument has status label, should be NOT_FOUND + if "status" in m.metric.labels: + assert m.metric.labels["status"] == "NOT_FOUND" + # check for cluster and zone + assert m.resource.labels["zone"] == "global" + assert m.resource.labels["cluster"] == "unspecified" + # each rpc should have at least one connectivity error + # ReadRows will have more, since we test point reads and streams + connectivity_error_counts = [ + m for m in fail_metrics if "connectivity_error_count" in m.metric.type + ] + for error_metric in connectivity_error_counts: + total_points = sum([int(pt.value.int64_value) for pt in error_metric.points]) + assert total_points >= 1 + # ensure each rpc reported connectivity errors + for prc in OperationType: + assert any( + m.metric.labels["method"] == prc.value for m in connectivity_error_counts + ) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + "app_profile,final_status", + [ + ("retry_then_success", "OK"), + ("retry_then_terminal", "NOT_FOUND"), + ("retry_then_timeout", "DEADLINE_EXCEEDED"), + ], +) +async def test_status_retry(get_all_metrics, app_profile, final_status): + """ + check the subset of calls that fail and retry + + All retry metrics should have 2 attempts before reaching final status + + Should have retries, connectivity_errors, and status of `final_status`.
+ cluster and zone may or may not change from the default value, depending on the final status + """ + # get set of all retry_then_success metrics + retry_metrics = [ + m for m in get_all_metrics if m.metric.labels["app_profile"] == app_profile + ] + # find each relevant instrument + retry_counts = [m for m in retry_metrics if "retry_count" in m.metric.type] + assert len(retry_counts) > 0 + connectivity_error_counts = [ + m for m in retry_metrics if "connectivity_error_count" in m.metric.type + ] + assert len(connectivity_error_counts) > 0 + operation_latencies = [ + m for m in retry_metrics if "operation_latencies" in m.metric.type + ] + assert len(operation_latencies) > 0 + attempt_latencies = [ + m for m in retry_metrics if "attempt_latencies" in m.metric.type + ] + assert len(attempt_latencies) > 0 + server_latencies = [ + m for m in retry_metrics if "server_latencies" in m.metric.type + ] # may not be present. only if reached server + first_response_latencies = [ + m for m in retry_metrics if "first_response_latencies" in m.metric.type + ] # may not be present + + # should have at least 2 retry attempts + # ReadRows will have more, because it is called multiple times in the test data + for m in retry_counts: + total_errors = sum([int(pt.value.int64_value) for pt in m.points]) + assert total_errors >= 2 + # each rpc should have at least one connectivity error + # most will have 2, but will have 1 if status == NOT_FOUND + for m in connectivity_error_counts: + total_errors = sum([int(pt.value.int64_value) for pt in m.points]) + assert total_errors >= 1 + + # all operation-level status should be final_status + for m in operation_latencies + retry_counts: + assert m.metric.labels["status"] == final_status + + # check attempt statuses + attempt_statuses = set( + [ + m.metric.labels["status"] + for m in attempt_latencies + + server_latencies + + first_response_latencies + + connectivity_error_counts + ] + ) + if final_status == "DEADLINE_EXCEEDED": + # operation 
DEADLINE_EXCEEDED never shows up in attempts + assert len(attempt_statuses) == 1 + assert "UNAVAILABLE" in attempt_statuses + else: + # all other attempt-level status should have a mix of final_status and UNAVAILABLE + assert len(attempt_statuses) == 2 + assert "UNAVAILABLE" in attempt_statuses + assert final_status in attempt_statuses + + +@pytest.mark.asyncio +async def test_latency_metric_histogram_buckets(get_all_metrics): + """ + latency metrics should all have histogram buckets set up properly + """ + from google.cloud.bigtable.data._metrics.handlers.gcp_exporter import ( + MILLIS_AGGREGATION, + ) + + filtered = [m for m in get_all_metrics if "latency" in m.metric.type] + all_values = [pt.value.distribution_value for m in filtered for pt in m.points] + for v in all_values: + # check bucket schema + assert v.bucket_options.explicit_buckets.bounds == MILLIS_AGGREGATION.boundaries + # check for reasonable values + assert v.count > 0 + assert v.mean > 0 + assert v.mean < 5000 diff --git a/tests/system/data/test_system.py b/tests/system/data/test_system.py index aeb08fc1a..ca40f9f9a 100644 --- a/tests/system/data/test_system.py +++ b/tests/system/data/test_system.py @@ -25,6 +25,8 @@ TEST_FAMILY = "test-family" TEST_FAMILY_2 = "test-family-2" +TEST_ZONE = "us-central1-b" +TEST_CLUSTER = "test-cluster" @pytest.fixture(scope="session") @@ -53,8 +55,8 @@ def cluster_config(project_id): from google.cloud.bigtable_admin_v2 import types cluster = { - "test-cluster": types.Cluster( - location=f"projects/{project_id}/locations/us-central1-b", + TEST_CLUSTER: types.Cluster( + location=f"projects/{project_id}/locations/{TEST_ZONE}", serve_nodes=1, ) } diff --git a/tests/unit/data/_async/__init__.py b/tests/unit/data/_async/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/unit/data/_async/test__mutate_rows.py b/tests/unit/data/_async/test__mutate_rows.py index e03028c45..ffee14087 100644 --- a/tests/unit/data/_async/test__mutate_rows.py +++
b/tests/unit/data/_async/test__mutate_rows.py @@ -18,6 +18,8 @@ from google.rpc import status_pb2 import google.api_core.exceptions as core_exceptions +from .test_client import mock_grpc_call + # try/except added for compatibility with python < 3.8 try: from unittest import mock @@ -48,6 +50,7 @@ def _make_one(self, *args, **kwargs): kwargs["table"] = kwargs.pop("table", AsyncMock()) kwargs["operation_timeout"] = kwargs.pop("operation_timeout", 5) kwargs["attempt_timeout"] = kwargs.pop("attempt_timeout", 0.1) + kwargs["metrics"] = kwargs.pop("metrics", mock.Mock()) kwargs["retryable_exceptions"] = kwargs.pop("retryable_exceptions", ()) kwargs["mutation_entries"] = kwargs.pop("mutation_entries", []) return self._target_class()(*args, **kwargs) @@ -67,9 +70,17 @@ def _make_mock_gapic(self, mutation_list, error_dict=None): mock_fn = AsyncMock() if error_dict is None: error_dict = {} - mock_fn.side_effect = lambda *args, **kwargs: self._mock_stream( - mutation_list, error_dict - ) + responses = [ + MutateRowsResponse( + entries=[ + MutateRowsResponse.Entry( + index=idx, status=status_pb2.Status(code=error_dict.get(idx, 0)) + ) + ] + ) + for idx, _ in enumerate(mutation_list) + ] + mock_fn.return_value = mock_grpc_call(stream_response=responses) return mock_fn def test_ctor(self): @@ -86,6 +97,7 @@ def test_ctor(self): entries = [_make_mutation(), _make_mutation()] operation_timeout = 0.05 attempt_timeout = 0.01 + metrics = mock.Mock() retryable_exceptions = () instance = self._make_one( client, @@ -93,6 +105,7 @@ def test_ctor(self): entries, operation_timeout, attempt_timeout, + metrics, retryable_exceptions, ) # running gapic_fn should trigger a client call @@ -123,6 +136,7 @@ def test_ctor(self): assert instance.is_retryable(RuntimeError("")) is False assert instance.remaining_indices == list(range(len(entries))) assert instance.errors == {} + assert instance._operation_metrics == metrics def test_ctor_too_many_entries(self): """ @@ -139,8 +153,11 @@ def 
test_ctor_too_many_entries(self): entries = [_make_mutation()] * _MUTATE_ROWS_REQUEST_MUTATION_LIMIT operation_timeout = 0.05 attempt_timeout = 0.01 + metrics = mock.Mock() # no errors if at limit - self._make_one(client, table, entries, operation_timeout, attempt_timeout) + self._make_one( + client, table, entries, operation_timeout, attempt_timeout, metrics + ) # raise error after crossing with pytest.raises(ValueError) as e: self._make_one( @@ -149,6 +166,7 @@ def test_ctor_too_many_entries(self): entries + [_make_mutation()], operation_timeout, attempt_timeout, + metrics, ) assert "mutate_rows requests can contain at most 100000 mutations" in str( e.value @@ -169,7 +187,12 @@ async def test_mutate_rows_operation(self): f"{cls.__module__}.{cls.__name__}._run_attempt", AsyncMock() ) as attempt_mock: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance.start() assert attempt_mock.call_count == 1 @@ -191,7 +214,12 @@ async def test_mutate_rows_attempt_exception(self, exc_type): found_exc = None try: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance._run_attempt() except Exception as e: @@ -227,7 +255,12 @@ async def test_mutate_rows_exception(self, exc_type): found_exc = None try: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance.start() except MutationsExceptionGroup as e: @@ -270,6 +303,7 @@ async def test_mutate_rows_exception_retryable_eventually_pass(self, exc_type): entries, operation_timeout, operation_timeout, + mock.Mock(), retryable_exceptions=(exc_type,), ) await instance.start() @@ -294,17 +328,19 @@ async def 
test_mutate_rows_incomplete_ignored(self): AsyncMock(), ) as attempt_mock: attempt_mock.side_effect = _MutateRowsIncomplete("ignored") - found_exc = None - try: + with pytest.raises(MutationsExceptionGroup) as e: instance = self._make_one( - client, table, entries, operation_timeout, operation_timeout + client, + table, + entries, + operation_timeout, + operation_timeout, + mock.Mock(), ) await instance.start() - except MutationsExceptionGroup as e: - found_exc = e assert attempt_mock.call_count > 0 - assert len(found_exc.exceptions) == 1 - assert isinstance(found_exc.exceptions[0].__cause__, DeadlineExceeded) + assert len(e.value.exceptions) == 1 + assert isinstance(e.value.exceptions[0].__cause__, DeadlineExceeded) @pytest.mark.asyncio async def test_run_attempt_single_entry_success(self): diff --git a/tests/unit/data/_async/test__read_rows.py b/tests/unit/data/_async/test__read_rows.py index 4e7797c6d..31d6161d7 100644 --- a/tests/unit/data/_async/test__read_rows.py +++ b/tests/unit/data/_async/test__read_rows.py @@ -14,6 +14,7 @@ import pytest from google.cloud.bigtable.data._async._read_rows import _ReadRowsOperationAsync +from .test_client import mock_grpc_call # try/except added for compatibility with python < 3.8 try: @@ -60,6 +61,7 @@ def test_ctor(self): expected_operation_timeout = 42 expected_request_timeout = 44 time_gen_mock = mock.Mock() + metrics = mock.Mock() with mock.patch( "google.cloud.bigtable.data._async._read_rows._attempt_timeout_generator", time_gen_mock, @@ -69,6 +71,7 @@ def test_ctor(self): table, operation_timeout=expected_operation_timeout, attempt_timeout=expected_request_timeout, + metrics=metrics, ) assert time_gen_mock.call_count == 1 time_gen_mock.assert_called_once_with( @@ -87,6 +90,7 @@ def test_ctor(self): assert instance.request.table_name == table.table_name assert instance.request.app_profile_id == table.app_profile_id assert instance.request.rows_limit == row_limit + assert instance._operation_metrics == metrics 
@pytest.mark.parametrize( "in_keys,last_key,expected", @@ -228,31 +232,27 @@ async def test_revise_limit(self, start_limit, emit_num, expected_limit): from google.cloud.bigtable.data import ReadRowsQuery from google.cloud.bigtable_v2.types import ReadRowsResponse - async def awaitable_stream(): - async def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - query = ReadRowsQuery(limit=start_limit) table = mock.Mock() table.table_name = "table_name" table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) + instance = self._make_one(query, table, 10, 10, mock.Mock()) assert instance._remaining_count == start_limit # read emit_num rows - async for val in instance.chunk_stream(awaitable_stream()): + chunks = [ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + for i in range(emit_num) + ] + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in chunks] + ) + async for val in instance.chunk_stream(stream): pass assert instance._remaining_count == expected_limit @@ -267,32 +267,28 @@ async def test_revise_limit_over_limit(self, start_limit, emit_num): from google.cloud.bigtable_v2.types import ReadRowsResponse from google.cloud.bigtable.data.exceptions import InvalidChunk - async def awaitable_stream(): - async def mock_stream(): - for i in range(emit_num): - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk( - row_key=str(i).encode(), - family_name="b", - qualifier=b"c", - value=b"d", - commit_row=True, - ) - ] - ) - - return mock_stream() - query = ReadRowsQuery(limit=start_limit) table = mock.Mock() table.table_name = "table_name" table.app_profile_id = "app_profile_id" - instance = self._make_one(query, table, 10, 10) + 
instance = self._make_one(query, table, 10, 10, mock.Mock()) assert instance._remaining_count == start_limit with pytest.raises(InvalidChunk) as e: # read emit_num rows - async for val in instance.chunk_stream(awaitable_stream()): + chunks = [ + ReadRowsResponse.CellChunk( + row_key=str(i).encode(), + family_name="b", + qualifier=b"c", + value=b"d", + commit_row=True, + ) + for i in range(emit_num) + ] + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in chunks] + ) + async for val in instance.chunk_stream(stream): pass assert "emit count exceeds row limit" in str(e.value) @@ -302,17 +298,12 @@ async def test_aclose(self): should be able to close a stream safely with aclose. Closed generators should raise StopAsyncIteration on next yield """ - - async def mock_stream(): - while True: - yield 1 - with mock.patch.object( _ReadRowsOperationAsync, "_read_rows_attempt" ) as mock_attempt: - instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1) - wrapped_gen = mock_stream() - mock_attempt.return_value = wrapped_gen + instance = self._make_one(mock.Mock(), mock.Mock(), 1, 1, mock.Mock()) + call = mock_grpc_call(stream_response=range(100)) + mock_attempt.return_value = call gen = instance.start_operation() # read one row await gen.__anext__() @@ -323,7 +314,7 @@ async def mock_stream(): await gen.aclose() # ensure close was propagated to wrapped generator with pytest.raises(StopAsyncIteration): - await wrapped_gen.__anext__() + await call.__anext__() @pytest.mark.asyncio async def test_retryable_ignore_repeated_rows(self): @@ -336,26 +327,14 @@ async def test_retryable_ignore_repeated_rows(self): row_key = b"duplicate" - async def mock_awaitable_stream(): - async def mock_stream(): - while True: - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - yield ReadRowsResponse( - chunks=[ - ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True) - ] - ) - - return mock_stream() - 
instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - stream = _ReadRowsOperationAsync.chunk_stream(instance, mock_awaitable_stream()) + chunks = [ReadRowsResponse.CellChunk(row_key=row_key, commit_row=True)] * 2 + grpc_call = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in chunks] + ) + stream = _ReadRowsOperationAsync.chunk_stream(instance, grpc_call) await stream.__anext__() with pytest.raises(InvalidChunk) as exc: await stream.__anext__() diff --git a/tests/unit/data/_async/test_client.py b/tests/unit/data/_async/test_client.py index a0019947d..444cbdcf3 100644 --- a/tests/unit/data/_async/test_client.py +++ b/tests/unit/data/_async/test_client.py @@ -45,6 +45,70 @@ ) +class mock_grpc_call: + """ + Used for mocking the responses from grpc calls. Can simulate both unary and streaming calls. + """ + + def __init__( + self, + unary_response=None, + stream_response=(), + sleep_time=0, + initial_metadata=grpc.aio.Metadata(), + trailing_metadata=grpc.aio.Metadata(), + ): + self.unary_response = unary_response + self.stream_response = stream_response + self.sleep_time = sleep_time + self.stream_idx = -1 + self._future = asyncio.get_event_loop().create_future() + self._future.set_result(unary_response) + self._initial_metadata = initial_metadata + self._trailing_metadata = trailing_metadata + + def __await__(self): + response = yield from self._future.__await__() + if response is None: + # await is a no-op for streaming calls + return self + # otherwise return unary response + return response + + def __aiter__(self): + return self + + async def __anext__(self): + self.stream_idx += 1 + if self.stream_idx < len(self.stream_response): + await asyncio.sleep(self.sleep_time) + next_val = self.stream_response[self.stream_idx] + if isinstance(next_val, Exception): + raise next_val + return next_val + raise StopAsyncIteration + + def cancel(self): + pass + + async def asend(self, val): + """ + implement 
generator protocol, so retries will treat this as a generator + i.e, call aclose at end of stream + """ + return await self.__anext__() + + async def aclose(self): + # simulate closing streams by jumping to the end + self.stream_idx = len(self.stream_response) + + async def trailing_metadata(self): + return self._trailing_metadata + + async def initial_metadata(self): + return self._initial_metadata + + def _make_client(*args, use_emulator=True, **kwargs): import os from google.cloud.bigtable.data._async.client import BigtableDataClientAsync @@ -1032,6 +1096,10 @@ class TestTableAsync: async def test_table_ctor(self): from google.cloud.bigtable.data._async.client import TableAsync from google.cloud.bigtable.data._async.client import _WarmedInstanceKey + from google.cloud.bigtable.data._metrics import ( + BigtableClientSideMetricsController, + ) + from google.cloud.bigtable.data._metrics import OpenTelemetryMetricsHandler expected_table_id = "table-id" expected_instance_id = "instance-id" @@ -1085,6 +1153,10 @@ async def test_table_ctor(self): table.default_mutate_rows_attempt_timeout == expected_mutate_rows_attempt_timeout ) + # check metrics object + assert isinstance(table._metrics, BigtableClientSideMetricsController) + assert len(table._metrics.handlers) == 1 + assert isinstance(table._metrics.handlers[0], OpenTelemetryMetricsHandler) # ensure task reaches completion await table._register_instance_task assert table._register_instance_task.done() @@ -1250,7 +1322,8 @@ async def test_customizable_retryable_errors( with mock.patch(retry_fn_path) as retry_fn_mock: async with _make_client() as client: table = client.get_table("instance-id", "table-id") - expected_predicate = lambda a: a in expected_retryables # noqa + expected_predicate = mock.Mock() + expected_predicate.side_effect = lambda exc: exc in expected_retryables retry_fn_mock.side_effect = RuntimeError("stop early") with mock.patch( "google.api_core.retry.if_exception_type" @@ -1266,7 +1339,13 @@ async 
def test_customizable_retryable_errors( ) retry_call_args = retry_fn_mock.call_args_list[0].args # output of if_exception_type should be sent in to retry constructor - assert retry_call_args[1] is expected_predicate + # note: may be wrapped by metrics + assert expected_predicate.call_count == 0 + found_predicate = retry_call_args[1] + obj = RuntimeError("test") + found_predicate(obj) + assert expected_predicate.call_count == 1 + assert expected_predicate.called_with(obj) @pytest.mark.parametrize( "fn_name,fn_args,gapic_fn", @@ -1348,6 +1427,7 @@ def _make_table(self, *args, **kwargs): ) client_mock._gapic_client.table_path.return_value = kwargs["table_id"] client_mock._gapic_client.instance_path.return_value = kwargs["instance_id"] + client_mock.project = "test-project" return TableAsync(client_mock, *args, **kwargs) def _make_stats(self): @@ -1385,31 +1465,11 @@ async def _make_gapic_stream( ): from google.cloud.bigtable_v2 import ReadRowsResponse - class mock_stream: - def __init__(self, chunk_list, sleep_time): - self.chunk_list = chunk_list - self.idx = -1 - self.sleep_time = sleep_time - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - if sleep_time: - await asyncio.sleep(self.sleep_time) - chunk = self.chunk_list[self.idx] - if isinstance(chunk, Exception): - raise chunk - else: - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list, sleep_time) + pb_list = [ + c if isinstance(c, Exception) else ReadRowsResponse(chunks=[c]) + for c in chunk_list + ] + return mock_grpc_call(stream_response=pb_list, sleep_time=sleep_time) async def execute_fn(self, table, *args, **kwargs): return await table.read_rows(*args, **kwargs) @@ -1482,17 +1542,16 @@ async def test_read_rows_timeout(self, operation_timeout): async with self._make_table() as table: read_rows = table.client._gapic_client.read_rows query = ReadRowsQuery() - 
chunks = [self._make_chunk(row_key=b"test_1")] + chunks = [core_exceptions.DeadlineExceeded("test timeout")] * 5 read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - chunks, sleep_time=1 + chunks, sleep_time=0.05 ) - try: + with pytest.raises(core_exceptions.DeadlineExceeded) as e: await table.read_rows(query, operation_timeout=operation_timeout) - except core_exceptions.DeadlineExceeded as e: - assert ( - e.message - == f"operation_timeout of {operation_timeout:0.1f}s exceeded" - ) + assert ( + e.value.message + == f"operation_timeout of {operation_timeout:0.1f}s exceeded" + ) @pytest.mark.parametrize( "per_request_t, operation_t, expected_num", @@ -1528,22 +1587,21 @@ async def test_read_rows_attempt_timeout( query = ReadRowsQuery() chunks = [core_exceptions.DeadlineExceeded("mock deadline")] - try: + with pytest.raises(core_exceptions.DeadlineExceeded) as e: await table.read_rows( query, operation_timeout=operation_t, attempt_timeout=per_request_t, ) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - if expected_num == 0: - assert retry_exc is None - else: - assert type(retry_exc) is RetryExceptionGroup - assert f"{expected_num} failed attempts" in str(retry_exc) - assert len(retry_exc.exceptions) == expected_num - for sub_exc in retry_exc.exceptions: - assert sub_exc.message == "mock deadline" + retry_exc = e.value.__cause__ + if expected_num == 0: + assert retry_exc is None + else: + assert type(retry_exc) is RetryExceptionGroup + assert f"{expected_num} failed attempts" in str(retry_exc) + assert len(retry_exc.exceptions) == expected_num + for sub_exc in retry_exc.exceptions: + assert sub_exc.message == "mock deadline" assert read_rows.call_count == expected_num # check timeouts for _, call_kwargs in read_rows.call_args_list[:-1]: @@ -1575,13 +1633,12 @@ async def test_read_rows_retryable_error(self, exc_type): ) query = ReadRowsQuery() expected_error = exc_type("mock error") - try: + with 
pytest.raises(core_exceptions.DeadlineExceeded) as e: await table.read_rows(query, operation_timeout=0.1) - except core_exceptions.DeadlineExceeded as e: - retry_exc = e.__cause__ - root_cause = retry_exc.exceptions[0] - assert type(root_cause) is exc_type - assert root_cause == expected_error + retry_exc = e.value.__cause__ + root_cause = retry_exc.exceptions[0] + assert type(root_cause) is exc_type + assert root_cause == expected_error @pytest.mark.parametrize( "exc_type", @@ -1599,17 +1656,16 @@ async def test_read_rows_retryable_error(self, exc_type): ) @pytest.mark.asyncio async def test_read_rows_non_retryable_error(self, exc_type): - async with self._make_table() as table: - read_rows = table.client._gapic_client.read_rows - read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( - [expected_error] - ) - query = ReadRowsQuery() - expected_error = exc_type("mock error") - try: - await table.read_rows(query, operation_timeout=0.1) - except exc_type as e: - assert e == expected_error + table = self._make_table() + read_rows = table.client._gapic_client.read_rows + read_rows.side_effect = lambda *args, **kwargs: self._make_gapic_stream( + [expected_error] + ) + query = ReadRowsQuery() + expected_error = exc_type("mock error") + with pytest.raises(exc_type) as e: + await table.read_rows(query, operation_timeout=0.1) + assert e.value == expected_error @pytest.mark.asyncio async def test_read_rows_revise_request(self): @@ -1636,15 +1692,14 @@ async def test_read_rows_revise_request(self): self._make_chunk(row_key=b"test_1"), core_exceptions.Aborted("mock retryable error"), ] - try: + with pytest.raises(InvalidChunk): await table.read_rows(query) - except InvalidChunk: - revise_rowset.assert_called() - first_call_kwargs = revise_rowset.call_args_list[0].kwargs - assert first_call_kwargs["row_set"] == query._to_pb(table).rows - assert first_call_kwargs["last_seen_row_key"] == b"test_1" - revised_call = read_rows.call_args_list[1].args[0] - assert 
revised_call.rows == return_val + revise_rowset.assert_called() + first_call_kwargs = revise_rowset.call_args_list[0].kwargs + assert first_call_kwargs["row_set"] == query._to_pb(table).rows + assert first_call_kwargs["last_seen_row_key"] == b"test_1" + revised_call = read_rows.call_args_list[1].args[0] + assert revised_call.rows == return_val @pytest.mark.asyncio async def test_read_rows_default_timeouts(self): @@ -1832,10 +1887,10 @@ class TestReadRowsSharded: @pytest.mark.asyncio async def test_read_rows_sharded_empty_query(self): async with _make_client() as client: - async with client.get_table("instance", "table") as table: - with pytest.raises(ValueError) as exc: - await table.read_rows_sharded([]) - assert "empty sharded_query" in str(exc.value) + table = client.get_table("instance", "table") + with pytest.raises(ValueError) as exc: + await table.read_rows_sharded([]) + assert "empty sharded_query" in str(exc.value) @pytest.mark.asyncio async def test_read_rows_sharded_multiple_queries(self): @@ -1987,11 +2042,14 @@ async def test_read_rows_sharded_batching(self): class TestSampleRowKeys: - async def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): + def _make_gapic_stream(self, sample_list: list[tuple[bytes, int]]): from google.cloud.bigtable_v2.types import SampleRowKeysResponse - for value in sample_list: - yield SampleRowKeysResponse(row_key=value[0], offset_bytes=value[1]) + pb_list = [ + SampleRowKeysResponse(row_key=s[0], offset_bytes=s[1]) for s in sample_list + ] + + return mock_grpc_call(stream_response=pb_list) @pytest.mark.asyncio async def test_sample_row_keys(self): @@ -2157,32 +2215,30 @@ async def test_mutate_row(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.return_value = None - 
await table.mutate_row( - "row_key", - mutation_arg, - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args_list[0].kwargs - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["row_key"] == b"row_key" - formatted_mutations = ( - [mutation._to_pb() for mutation in mutation_arg] - if isinstance(mutation_arg, list) - else [mutation_arg._to_pb()] - ) - assert kwargs["mutations"] == formatted_mutations - assert kwargs["timeout"] == expected_attempt_timeout - # make sure gapic layer is not retrying - assert kwargs["retry"] is None + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.return_value = mock_grpc_call() + await table.mutate_row( + "row_key", + mutation_arg, + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args_list[0].kwargs + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["row_key"] == b"row_key" + formatted_mutations = ( + [mutation._to_pb() for mutation in mutation_arg] + if isinstance(mutation_arg, list) + else [mutation_arg._to_pb()] + ) + assert kwargs["mutations"] == formatted_mutations + assert kwargs["timeout"] == expected_attempt_timeout + # make sure gapic layer is not retrying + assert kwargs["retry"] is None @pytest.mark.parametrize( "retryable_exception", @@ -2197,20 +2253,16 @@ async def test_mutate_row_retryable_errors(self, retryable_exception): from google.cloud.bigtable.data.exceptions import RetryExceptionGroup async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(DeadlineExceeded) as e: - mutation = 
mutations.DeleteAllFromRow() - assert mutation.is_idempotent() is True - await table.mutate_row( - "row_key", mutation, operation_timeout=0.01 - ) - cause = e.value.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(DeadlineExceeded) as e: + mutation = mutations.DeleteAllFromRow() + assert mutation.is_idempotent() is True + await table.mutate_row("row_key", mutation, operation_timeout=0.01) + cause = e.value.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) @pytest.mark.parametrize( "retryable_exception", @@ -2227,19 +2279,13 @@ async def test_mutate_row_non_idempotent_retryable_errors( Non-idempotent mutations should not be retried """ async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(retryable_exception): - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - assert mutation.is_idempotent() is False - await table.mutate_row( - "row_key", mutation, operation_timeout=0.2 - ) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(retryable_exception): + mutation = mutations.SetCell("family", b"qualifier", b"value", -1) + assert mutation.is_idempotent() is False + await table.mutate_row("row_key", mutation, operation_timeout=0.2) @pytest.mark.parametrize( "non_retryable_exception", @@ -2255,46 +2301,18 @@ async def 
test_mutate_row_non_idempotent_retryable_errors( @pytest.mark.asyncio async def test_mutate_row_non_retryable_errors(self, non_retryable_exception): async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_row" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(non_retryable_exception): - mutation = mutations.SetCell( - "family", - b"qualifier", - b"value", - timestamp_micros=1234567890, - ) - assert mutation.is_idempotent() is True - await table.mutate_row( - "row_key", mutation, operation_timeout=0.2 - ) - - @pytest.mark.parametrize("include_app_profile", [True, False]) - @pytest.mark.asyncio - async def test_mutate_row_metadata(self, include_app_profile): - """request should attach metadata headers""" - profile = "profile" if include_app_profile else None - async with _make_client() as client: - async with client.get_table("i", "t", app_profile_id=profile) as table: - with mock.patch.object( - client._gapic_client, "mutate_row", AsyncMock() - ) as read_rows: - await table.mutate_row("rk", mock.Mock()) - kwargs = read_rows.call_args_list[0].kwargs - metadata = kwargs["metadata"] - goog_metadata = None - for key, value in metadata: - if key == "x-goog-request-params": - goog_metadata = value - assert goog_metadata is not None, "x-goog-request-params not found" - assert "table_name=" + table.table_name in goog_metadata - if include_app_profile: - assert "app_profile_id=profile" in goog_metadata - else: - assert "app_profile_id=" not in goog_metadata + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_row") as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(non_retryable_exception): + mutation = mutations.SetCell( + "family", + b"qualifier", + b"value", + timestamp_micros=1234567890, + ) + assert 
mutation.is_idempotent() is True + await table.mutate_row("row_key", mutation, operation_timeout=0.2) @pytest.mark.parametrize("mutations", [[], None]) @pytest.mark.asyncio @@ -2326,10 +2344,7 @@ async def _mock_response(self, response_list): for i in range(len(response_list)) ] - async def generator(): - yield MutateRowsResponse(entries=entries) - - return generator() + return mock_grpc_call(stream_response=[MutateRowsResponse(entries=entries)]) @pytest.mark.asyncio @pytest.mark.asyncio @@ -2356,25 +2371,23 @@ async def test_bulk_mutate_rows(self, mutation_arg): """Test mutations with no errors""" expected_attempt_timeout = 19 async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.return_value = self._mock_response([None]) - bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) - await table.bulk_mutate_rows( - [bulk_mutation], - attempt_timeout=expected_attempt_timeout, - ) - assert mock_gapic.call_count == 1 - kwargs = mock_gapic.call_args[1] - assert ( - kwargs["table_name"] - == "projects/project/instances/instance/tables/table" - ) - assert kwargs["entries"] == [bulk_mutation._to_pb()] - assert kwargs["timeout"] == expected_attempt_timeout - assert kwargs["retry"] is None + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.return_value = self._mock_response([None]) + bulk_mutation = mutations.RowMutationEntry(b"row_key", mutation_arg) + await table.bulk_mutate_rows( + [bulk_mutation], + attempt_timeout=expected_attempt_timeout, + ) + assert mock_gapic.call_count == 1 + kwargs = mock_gapic.call_args[1] + assert ( + kwargs["table_name"] + == "projects/project/instances/instance/tables/table" + ) + assert kwargs["entries"] == [bulk_mutation._to_pb()] + assert kwargs["timeout"] == expected_attempt_timeout + 
assert kwargs["retry"] is None @pytest.mark.asyncio async def test_bulk_mutate_rows_multiple_entries(self): @@ -2467,24 +2480,22 @@ async def test_bulk_mutate_rows_idempotent_mutation_error_non_retryable( ) async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.DeleteAllFromRow() - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - await table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert "non-idempotent" not in str(failed_exception) - assert isinstance(failed_exception, FailedMutationEntryError) - cause = failed_exception.__cause__ - assert isinstance(cause, exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.DeleteAllFromRow() + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert "non-idempotent" not in str(failed_exception) + assert isinstance(failed_exception, FailedMutationEntryError) + cause = failed_exception.__cause__ + assert isinstance(cause, exception) @pytest.mark.parametrize( "retryable_exception", @@ -2507,25 +2518,23 @@ async def test_bulk_mutate_idempotent_retryable_request_errors( ) async with _make_client(project="project") as client: - async with 
client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - await table.bulk_mutate_rows([entry], operation_timeout=0.05) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, RetryExceptionGroup) - assert isinstance(cause.exceptions[0], retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.05) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, RetryExceptionGroup) + assert isinstance(cause.exceptions[0], retryable_exception) @pytest.mark.asyncio @pytest.mark.parametrize( @@ -2545,26 +2554,22 @@ async def test_bulk_mutate_rows_non_idempotent_retryable_errors( ) async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, 
"mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = lambda *a, **k: self._mock_response( - [retryable_exception("mock")] - ) - with pytest.raises(MutationsExceptionGroup) as e: - mutation = mutations.SetCell( - "family", b"qualifier", b"value", -1 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is False - await table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = lambda *a, **k: self._mock_response( + [retryable_exception("mock")] + ) + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell("family", b"qualifier", b"value", -1) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is False + await table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, retryable_exception) @pytest.mark.parametrize( "non_retryable_exception", @@ -2587,24 +2592,22 @@ async def test_bulk_mutate_rows_non_retryable_errors(self, non_retryable_excepti ) async with _make_client(project="project") as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "mutate_rows" - ) as mock_gapic: - mock_gapic.side_effect = non_retryable_exception("mock") - with pytest.raises(MutationsExceptionGroup) as e: - mutation = 
mutations.SetCell( - "family", b"qualifier", b"value", timestamp_micros=123 - ) - entry = mutations.RowMutationEntry(b"row_key", [mutation]) - assert mutation.is_idempotent() is True - await table.bulk_mutate_rows([entry], operation_timeout=0.2) - assert len(e.value.exceptions) == 1 - failed_exception = e.value.exceptions[0] - assert isinstance(failed_exception, FailedMutationEntryError) - assert "non-idempotent" not in str(failed_exception) - cause = failed_exception.__cause__ - assert isinstance(cause, non_retryable_exception) + table = client.get_table("instance", "table") + with mock.patch.object(client._gapic_client, "mutate_rows") as mock_gapic: + mock_gapic.side_effect = non_retryable_exception("mock") + with pytest.raises(MutationsExceptionGroup) as e: + mutation = mutations.SetCell( + "family", b"qualifier", b"value", timestamp_micros=123 + ) + entry = mutations.RowMutationEntry(b"row_key", [mutation]) + assert mutation.is_idempotent() is True + await table.bulk_mutate_rows([entry], operation_timeout=0.2) + assert len(e.value.exceptions) == 1 + failed_exception = e.value.exceptions[0] + assert isinstance(failed_exception, FailedMutationEntryError) + assert "non-idempotent" not in str(failed_exception) + cause = failed_exception.__cause__ + assert isinstance(cause, non_retryable_exception) @pytest.mark.asyncio async def test_bulk_mutate_error_index(self): @@ -2696,8 +2699,8 @@ async def test_check_and_mutate(self, gapic_result): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=gapic_result + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=gapic_result) ) row_key = b"row_key" predicate = None @@ -2752,8 +2755,8 @@ async def test_check_and_mutate_single_mutations(self): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - 
predicate_matched=True + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=True) ) true_mutation = SetCell("family", b"qualifier", b"value") false_mutation = SetCell("family", b"qualifier", b"value") @@ -2780,8 +2783,8 @@ async def test_check_and_mutate_predicate_object(self): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=True) ) await table.check_and_mutate_row( b"row_key", @@ -2808,8 +2811,8 @@ async def test_check_and_mutate_mutations_parsing(self): with mock.patch.object( client._gapic_client, "check_and_mutate_row" ) as mock_gapic: - mock_gapic.return_value = CheckAndMutateRowResponse( - predicate_matched=True + mock_gapic.return_value = mock_grpc_call( + CheckAndMutateRowResponse(predicate_matched=True) ) await table.check_and_mutate_row( b"row_key", @@ -2857,16 +2860,20 @@ async def test_read_modify_write_call_rule_args(self, call_rules, expected_rules """ Test that the gapic call is called with given rules """ + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + async with _make_client() as client: - async with client.get_table("instance", "table") as table: - with mock.patch.object( - client._gapic_client, "read_modify_write_row" - ) as mock_gapic: - await table.read_modify_write_row("key", call_rules) - assert mock_gapic.call_count == 1 - found_kwargs = mock_gapic.call_args_list[0][1] - assert found_kwargs["rules"] == expected_rules - assert found_kwargs["retry"] is None + table = client.get_table("instance", "table") + with mock.patch.object( + client._gapic_client, + "read_modify_write_row", + ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call(ReadModifyWriteRowResponse()) + await table.read_modify_write_row("key", call_rules) + assert mock_gapic.call_count == 1 + found_kwargs = 
mock_gapic.call_args_list[0][1] + assert found_kwargs["rules"] == expected_rules + assert found_kwargs["retry"] is None @pytest.mark.parametrize("rules", [[], None]) @pytest.mark.asyncio @@ -2879,6 +2886,8 @@ async def test_read_modify_write_no_rules(self, rules): @pytest.mark.asyncio async def test_read_modify_write_call_defaults(self): + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + instance = "instance1" table_id = "table1" project = "project1" @@ -2888,6 +2897,9 @@ async def test_read_modify_write_call_defaults(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call( + ReadModifyWriteRowResponse() + ) await table.read_modify_write_row(row_key, mock.Mock()) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] @@ -2901,6 +2913,8 @@ async def test_read_modify_write_call_defaults(self): @pytest.mark.asyncio async def test_read_modify_write_call_overrides(self): + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + row_key = b"row_key1" expected_timeout = 12345 profile_id = "profile1" @@ -2911,6 +2925,9 @@ async def test_read_modify_write_call_overrides(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call( + ReadModifyWriteRowResponse() + ) await table.read_modify_write_row( row_key, mock.Mock(), @@ -2924,12 +2941,17 @@ async def test_read_modify_write_call_overrides(self): @pytest.mark.asyncio async def test_read_modify_write_string_key(self): + from google.cloud.bigtable_v2.types import ReadModifyWriteRowResponse + row_key = "string_row_key1" async with _make_client() as client: async with client.get_table("instance", "table_id") as table: with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call( + ReadModifyWriteRowResponse() + ) await 
table.read_modify_write_row(row_key, mock.Mock()) assert mock_gapic.call_count == 1 kwargs = mock_gapic.call_args_list[0][1] @@ -2950,8 +2972,8 @@ async def test_read_modify_write_row_building(self): with mock.patch.object( client._gapic_client, "read_modify_write_row" ) as mock_gapic: + mock_gapic.return_value = mock_grpc_call(mock_response) with mock.patch.object(Row, "_from_pb") as constructor_mock: - mock_gapic.return_value = mock_response await table.read_modify_write_row("key", mock.Mock()) assert constructor_mock.call_count == 1 constructor_mock.assert_called_once_with(mock_response.row) diff --git a/tests/unit/data/_async/test_mutations_batcher.py b/tests/unit/data/_async/test_mutations_batcher.py index cca7c9824..b1237ca38 100644 --- a/tests/unit/data/_async/test_mutations_batcher.py +++ b/tests/unit/data/_async/test_mutations_batcher.py @@ -903,7 +903,8 @@ async def test__execute_mutate_rows(self, mutate_rows): table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: batch = [_make_mutation()] - result = await instance._execute_mutate_rows(batch) + mock_metric = mock.Mock() + result = await instance._execute_mutate_rows(batch, mock_metric) assert start_operation.call_count == 1 args, kwargs = mutate_rows.call_args assert args[0] == table.client._gapic_client @@ -911,6 +912,7 @@ async def test__execute_mutate_rows(self, mutate_rows): assert args[2] == batch kwargs["operation_timeout"] == 17 kwargs["attempt_timeout"] == 13 + kwargs["metrics"] == mock_metric assert result == [] @pytest.mark.asyncio @@ -933,7 +935,7 @@ async def test__execute_mutate_rows_returns_errors(self, mutate_rows): table.default_mutate_rows_retryable_errors = () async with self._make_one(table) as instance: batch = [_make_mutation()] - result = await instance._execute_mutate_rows(batch) + result = await instance._execute_mutate_rows(batch, mock.Mock()) assert len(result) == 2 assert result[0] == err1 assert result[1] == err2 @@ -1058,7 +1060,7 @@ 
async def test_timeout_args_passed(self, mutate_rows): assert instance._operation_timeout == expected_operation_timeout assert instance._attempt_timeout == expected_attempt_timeout # make simulated gapic call - await instance._execute_mutate_rows([_make_mutation()]) + await instance._execute_mutate_rows([_make_mutation()], mock.Mock()) assert mutate_rows.call_count == 1 kwargs = mutate_rows.call_args[1] assert kwargs["operation_timeout"] == expected_operation_timeout @@ -1174,7 +1176,8 @@ async def test_customizable_retryable_errors( predicate_builder_mock.return_value = expected_predicate retry_fn_mock.side_effect = RuntimeError("stop early") mutation = _make_mutation(count=1, size=1) - await instance._execute_mutate_rows([mutation]) + predicate_builder_mock.reset_mock() + await instance._execute_mutate_rows([mutation], mock.Mock()) # passed in errors should be used to build the predicate predicate_builder_mock.assert_called_once_with( *expected_retryables, _MutateRowsIncomplete @@ -1182,3 +1185,37 @@ async def test_customizable_retryable_errors( retry_call_args = retry_fn_mock.call_args_list[0].args # output of if_exception_type should be sent in to retry constructor assert retry_call_args[1] is expected_predicate + + @pytest.mark.asyncio + @pytest.mark.parametrize("sleep_time,flow_size", [(0, 10), (0.1, 1), (0.01, 10)]) + async def test_flow_throttling_metric(self, sleep_time, flow_size): + """ + When there are delays due to waiting on flow control, + should be reflected in operation metric's flow_throttling_time + """ + import time + from google.cloud.bigtable.data._metrics import ( + BigtableClientSideMetricsController, + ) + from google.cloud.bigtable.data._metrics import ActiveOperationMetric + + # create mock call + async def mock_add_to_flow(): + time.sleep(sleep_time) + for _ in range(flow_size): + await asyncio.sleep(0) + yield mock.Mock() + + mock_instance = mock.Mock() + mock_instance._wait_for_batch_results.return_value = asyncio.sleep(0) + 
mock_instance._entries_processed_since_last_raise = 0 + mock_instance._table._metrics = BigtableClientSideMetricsController([]) + mock_instance._flow_control.add_to_flow.return_value = mock_add_to_flow() + await self._get_target_class()._flush_internal(mock_instance, []) + # get list of metrics + mock_bg_task = mock_instance._create_bg_task + metric_list = [arg[0][-1] for arg in mock_bg_task.call_args_list] + # make sure operations were set up as expected + assert len(metric_list) == flow_size + assert all([isinstance(m, ActiveOperationMetric) for m in metric_list]) + assert abs(metric_list[0].flow_throttling_time - sleep_time) < 0.002 diff --git a/tests/unit/data/_metrics/handlers/test_handler_gcp_exporter.py b/tests/unit/data/_metrics/handlers/test_handler_gcp_exporter.py index c353cd681..72d399bf9 100644 --- a/tests/unit/data/_metrics/handlers/test_handler_gcp_exporter.py +++ b/tests/unit/data/_metrics/handlers/test_handler_gcp_exporter.py @@ -100,28 +100,6 @@ def test_uses_custom_meter_provider(self, mock_meter_provider, mock_instruments) assert isinstance(found_reader._exporter, _BigtableMetricsExporter) assert found_reader._exporter.project_name == f"projects/{project_id}" - @mock.patch( - "google.cloud.bigtable.data._metrics.handlers.gcp_exporter.PeriodicExportingMetricReader", - autospec=True, - ) - def test_custom_export_interval(self, mock_reader): - """ - should be able to set a custom export interval - """ - input_interval = 123 - try: - self._make_one( - export_interval=input_interval, - project_id="p", - instance_id="i", - table_id="t", - ) - except Exception: - pass - reader_init_kwargs = mock_reader.call_args[1] - found_interval = reader_init_kwargs["export_interval_millis"] - assert found_interval == input_interval * 1000 # convert to ms - class Test_BigtableMetricsExporter: def _get_class(self): diff --git a/tests/unit/data/_metrics/handlers/test_handler_opentelemetry.py b/tests/unit/data/_metrics/handlers/test_handler_opentelemetry.py index 
94c1acc3b..20e55324c 100644 --- a/tests/unit/data/_metrics/handlers/test_handler_opentelemetry.py +++ b/tests/unit/data/_metrics/handlers/test_handler_opentelemetry.py @@ -123,6 +123,7 @@ def test__generate_client_uid(self): Should generate a unique id with format `python-@` """ import re + instance = self._make_one() # test with random values uid = instance._generate_client_uid() @@ -141,7 +142,6 @@ def test__generate_client_uid(self): uid = instance._generate_client_uid() assert uid == "python-uuid-@localhost" - @pytest.mark.parametrize( "metric_name,kind,optional_labels", [ diff --git a/tests/unit/data/_metrics/test_data_model.py b/tests/unit/data/_metrics/test_data_model.py index dedc0b494..3caf81d9d 100644 --- a/tests/unit/data/_metrics/test_data_model.py +++ b/tests/unit/data/_metrics/test_data_model.py @@ -601,7 +601,7 @@ def test_end_on_empty_operation(self): assert final_op.final_status == StatusCode.OK assert final_op.completed_attempts == [] - def test_build_wrapped_predicate(self): + def test_build_wrapped_fn_handlers_predicate(self): """ predicate generated by object should terminate attempt or operation based on passed in predicate @@ -610,23 +610,61 @@ def test_build_wrapped_predicate(self): cls = type(self._make_one(object())) # ensure predicate is called with the exception mock_predicate = mock.Mock() - cls.build_wrapped_predicate(mock.Mock(), mock_predicate)(input_exc) + pred, _ = cls.build_wrapped_fn_handlers(mock.Mock(), mock_predicate) + pred(input_exc) assert mock_predicate.call_count == 1 assert mock_predicate.call_args[0][0] == input_exc assert len(mock_predicate.call_args[0]) == 1 # if predicate is true, end the attempt mock_instance = mock.Mock() - cls.build_wrapped_predicate(mock_instance, lambda x: True)(input_exc) + pred, _ = cls.build_wrapped_fn_handlers(mock_instance, lambda x: True) + pred(input_exc) assert mock_instance.end_attempt_with_status.call_count == 1 assert mock_instance.end_attempt_with_status.call_args[0][0] == input_exc 
assert len(mock_instance.end_attempt_with_status.call_args[0]) == 1 # if predicate is false, end the operation mock_instance = mock.Mock() - cls.build_wrapped_predicate(mock_instance, lambda x: False)(input_exc) + pred, _ = cls.build_wrapped_fn_handlers(mock_instance, lambda x: False) + pred(input_exc) assert mock_instance.end_with_status.call_count == 1 assert mock_instance.end_with_status.call_args[0][0] == input_exc assert len(mock_instance.end_with_status.call_args[0]) == 1 + def test_build_wrapped_fn_handlers_exc_factory(self): + """ + exception factory generated by object should terminate operation + on timeout + """ + from google.api_core.retry import RetryFailureReason + from google.api_core.exceptions import DeadlineExceeded + + cls = type(self._make_one(object())) + # ensure inner factory is called with the exception + _, factory = cls.build_wrapped_fn_handlers(mock.Mock(), None) + with mock.patch( + "google.cloud.bigtable.data._metrics.data_model._retry_exception_factory" + ) as mock_factory: + expected_return = (object(), object()) + mock_factory.return_value = expected_return + args = ("a", "b", "c") + got_result = factory(*args) + assert expected_return == got_result + assert mock_factory.call_count == 1 + assert mock_factory.call_args[0] == args + + # if called with reason == TIMEOUT, end the operation + mock_instance = mock.Mock() + _, factory = cls.build_wrapped_fn_handlers(mock_instance, None) + factory([], RetryFailureReason.TIMEOUT, None) + assert mock_instance.end_with_status.call_count == 1 + assert type(mock_instance.end_with_status.call_args[0][0]) == DeadlineExceeded + + # if called with reason==NON_RETRYABLE_ERROR, do not + mock_instance = mock.Mock() + _, factory = cls.build_wrapped_fn_handlers(mock_instance, None) + factory([], RetryFailureReason.NON_RETRYABLE_ERROR, None) + assert mock_instance.end_with_status.call_count == 0 + def test__exc_to_status(self): """ Should return grpc_status_code if grpc error, otherwise UNKNOWN @@ -671,15 
+709,22 @@ def test__exc_to_status(self): for exc in custom_excs: assert cls._exc_to_status(exc) == cause_exc.grpc_status_code, exc # extract most recent exception for bigtable exception groups + # if retry is cause, unwrap retry + retry_exc = bt_exc.RetryExceptionGroup([RuntimeError(), cause_exc]) exc_groups = [ - bt_exc._BigtableExceptionGroup("", [ValueError(), cause_exc]), - bt_exc.RetryExceptionGroup([RuntimeError(), cause_exc]), + retry_exc, bt_exc.ShardedReadRowsExceptionGroup( [bt_exc.FailedQueryShardError(1, {}, cause=cause_exc)], [], 2 ), + bt_exc.ShardedReadRowsExceptionGroup( + [bt_exc.FailedQueryShardError(1, {}, cause=retry_exc)], [], 2 + ), bt_exc.MutationsExceptionGroup( [bt_exc.FailedMutationEntryError(1, mock.Mock(), cause=cause_exc)], 2 ), + bt_exc.MutationsExceptionGroup( + [bt_exc.FailedMutationEntryError(1, mock.Mock(), cause=retry_exc)], 2 + ), ] for exc in exc_groups: assert cls._exc_to_status(exc) == cause_exc.grpc_status_code, exc @@ -791,6 +836,27 @@ async def test_wrap_attempt_fn_success(self): assert len(metric.completed_attempts) == 1 assert metric.completed_attempts[0].end_status == StatusCode.OK + @pytest.mark.asyncio + async def test_wrap_attempt_fn_success_extract_call_metadata(self): + """ + When extract_call_metadata is True, should call add_response_metadata + on operation with output of wrapped function + """ + from .._async.test_client import mock_grpc_call + + metric = self._make_one(object()) + async with metric as context: + mock_call = mock_grpc_call() + inner_fn = lambda *args, **kwargs: mock_call # noqa + wrapped_fn = context.wrap_attempt_fn(inner_fn, extract_call_metadata=True) + with mock.patch.object( + metric, "add_response_metadata" + ) as mock_add_metadata: + # make the wrapped call + result = await wrapped_fn() + assert result == mock_call + assert mock_add_metadata.call_count == 1 + @pytest.mark.asyncio async def test_wrap_attempt_fn_failed_extract_call_metadata(self): """ diff --git 
a/tests/unit/data/_metrics/test_rpcs_instrumented.py b/tests/unit/data/_metrics/test_rpcs_instrumented.py new file mode 100644 index 000000000..6dc84f124 --- /dev/null +++ b/tests/unit/data/_metrics/test_rpcs_instrumented.py @@ -0,0 +1,315 @@ +# Copyright 2023 Google LLC +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +""" +This file tests each rpc method to ensure they support metrics properly +""" + +import pytest +import mock +import datetime +from grpc import StatusCode +from grpc.aio import Metadata + +from google.cloud.bigtable.data.read_rows_query import ReadRowsQuery +from google.cloud.bigtable.data import mutations +from google.cloud.bigtable.data._metrics import OperationType +from google.cloud.bigtable.data._metrics.data_model import BIGTABLE_METADATA_KEY +from google.cloud.bigtable.data._metrics.data_model import SERVER_TIMING_METADATA_KEY + +from .._async.test_client import mock_grpc_call + + +RPC_ARGS = "fn_name,fn_args,gapic_fn,is_unary,expected_type" +RETRYABLE_RPCS = [ + ( + "read_rows_stream", + (ReadRowsQuery(),), + "read_rows", + False, + OperationType.READ_ROWS, + ), + ("read_rows", (ReadRowsQuery(),), "read_rows", False, OperationType.READ_ROWS), + ("read_row", (b"row_key",), "read_rows", False, OperationType.READ_ROWS), + ( + "read_rows_sharded", + ([ReadRowsQuery()],), + "read_rows", + False, + OperationType.READ_ROWS, + ), + ("row_exists", (b"row_key",), "read_rows", False, OperationType.READ_ROWS), + ("sample_row_keys", (), 
"sample_row_keys", False, OperationType.SAMPLE_ROW_KEYS), + ( + "mutate_row", + (b"row_key", [mutations.DeleteAllFromRow()]), + "mutate_row", + False, + OperationType.MUTATE_ROW, + ), + ( + "bulk_mutate_rows", + ([mutations.RowMutationEntry(b"key", [mutations.DeleteAllFromRow()])],), + "mutate_rows", + False, + OperationType.BULK_MUTATE_ROWS, + ), +] +ALL_RPCS = RETRYABLE_RPCS + [ + ( + "check_and_mutate_row", + (b"row_key", None), + "check_and_mutate_row", + True, + OperationType.CHECK_AND_MUTATE, + ), + ( + "read_modify_write_row", + (b"row_key", mock.Mock()), + "read_modify_write_row", + True, + OperationType.READ_MODIFY_WRITE, + ), +] + + +@pytest.mark.parametrize(RPC_ARGS, ALL_RPCS) +@pytest.mark.asyncio +async def test_rpc_instrumented(fn_name, fn_args, gapic_fn, is_unary, expected_type): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + from google.cloud.bigtable_v2.types import ResponseParams + + cluster_data = "my-cluster" + zone_data = "my-zone" + expected_gfe_latency_ms = 123 + + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}" + ) as gapic_mock: + if is_unary: + unary_response = mock.Mock() + unary_response.row.families = [] # patch for read_modify_write_row + else: + unary_response = None + # populate metadata fields + initial_metadata = Metadata( + ( + BIGTABLE_METADATA_KEY, + ResponseParams.serialize( + ResponseParams(zone_id=zone_data, cluster_id=cluster_data) + ), + ) + ) + trailing_metadata = Metadata( + (SERVER_TIMING_METADATA_KEY, f"gfet4t7; dur={expected_gfe_latency_ms}") + ) + grpc_call = mock_grpc_call( + unary_response=unary_response, + initial_metadata=initial_metadata, + trailing_metadata=trailing_metadata, + ) + gapic_mock.return_value = grpc_call + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id") + # customize metrics 
handlers + mock_metric_handler = mock.Mock() + table._metrics.handlers = [mock_metric_handler] + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args) + # iterate stream if it exists + try: + [i async for i in maybe_stream] + except TypeError: + pass + # check for recorded metrics values + assert mock_metric_handler.on_operation_complete.call_count == 1 + found_operation = mock_metric_handler.on_operation_complete.call_args[0][0] + # make sure expected fields were set properly + assert found_operation.op_type == expected_type + now = datetime.datetime.now(datetime.timezone.utc) + assert found_operation.start_time - now < datetime.timedelta(seconds=1) + assert found_operation.duration_ms < 100 + assert found_operation.duration_ms > 0 + assert found_operation.final_status == StatusCode.OK + assert found_operation.cluster_id == cluster_data + assert found_operation.zone == zone_data + # is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded + assert found_operation.is_streaming == ("read_rows" in fn_name) + # check attempts + assert len(found_operation.completed_attempts) == 1 + found_attempt = found_operation.completed_attempts[0] + assert found_attempt.end_status == StatusCode.OK + assert found_attempt.start_time - now < datetime.timedelta(seconds=1) + assert found_attempt.duration_ms < 100 + assert found_attempt.duration_ms > 0 + assert found_attempt.start_time >= found_operation.start_time + assert found_attempt.duration_ms <= found_operation.duration_ms + assert found_attempt.gfe_latency_ms == expected_gfe_latency_ms + # first response latency not populated, because no real read_rows chunks processed + assert found_attempt.first_response_latency_ms is None + # no application blocking time or backoff time expected + assert found_attempt.application_blocking_time_ms == 0 + assert found_attempt.backoff_before_attempt_ms == 0 + # no throttling expected + assert found_attempt.grpc_throttling_time_ms == 0 + 
assert found_operation.flow_throttling_time_ms == 0 + + +@pytest.mark.parametrize(RPC_ARGS, RETRYABLE_RPCS) +@pytest.mark.asyncio +async def test_rpc_instrumented_multiple_attempts( + fn_name, fn_args, gapic_fn, is_unary, expected_type +): + """check that all requests attach proper metadata headers, with a retry""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + from google.api_core.exceptions import Aborted + from google.cloud.bigtable_v2.types import MutateRowsResponse + from google.rpc.status_pb2 import Status + + with mock.patch( + f"google.cloud.bigtable_v2.BigtableAsyncClient.{gapic_fn}" + ) as gapic_mock: + if is_unary: + unary_response = mock.Mock() + unary_response.row.families = [] # patch for read_modify_write_row + else: + unary_response = None + grpc_call = mock_grpc_call(unary_response=unary_response) + if gapic_fn == "mutate_rows": + # patch response to send success + grpc_call.stream_response = [ + MutateRowsResponse( + entries=[MutateRowsResponse.Entry(index=0, status=Status(code=0))] + ) + ] + gapic_mock.side_effect = [Aborted("first attempt failed"), grpc_call] + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id") + # customize metrics handlers + mock_metric_handler = mock.Mock() + table._metrics.handlers = [mock_metric_handler] + test_fn = table.__getattribute__(fn_name) + maybe_stream = await test_fn(*fn_args, retryable_errors=(Aborted,)) + # iterate stream if it exists + try: + [_ async for _ in maybe_stream] + except TypeError: + pass + # check for recorded metrics values + assert mock_metric_handler.on_operation_complete.call_count == 1 + found_operation = mock_metric_handler.on_operation_complete.call_args[0][0] + # make sure expected fields were set properly + assert found_operation.op_type == expected_type + now = datetime.datetime.now(datetime.timezone.utc) + assert found_operation.start_time - now < 
datetime.timedelta(seconds=1) + assert found_operation.duration_ms < 100 + assert found_operation.duration_ms > 0 + assert found_operation.final_status == StatusCode.OK + # metadata wasn't set, should see default values + assert found_operation.cluster_id == "unspecified" + assert found_operation.zone == "global" + # is_streaming should only be true for read_rows, read_rows_stream, and read_rows_sharded + assert found_operation.is_streaming == ("read_rows" in fn_name) + # check attempts + assert len(found_operation.completed_attempts) == 2 + failure, success = found_operation.completed_attempts + for attempt in [success, failure]: + # check things that should be consistent across attempts + assert attempt.start_time - now < datetime.timedelta(seconds=1) + assert attempt.duration_ms < 100 + assert attempt.duration_ms > 0 + assert attempt.start_time >= found_operation.start_time + assert attempt.duration_ms <= found_operation.duration_ms + assert attempt.application_blocking_time_ms == 0 + assert success.end_status == StatusCode.OK + assert failure.end_status == StatusCode.ABORTED + assert success.start_time > failure.start_time + datetime.timedelta( + milliseconds=failure.duration_ms + ) + assert success.backoff_before_attempt_ms > 0 + assert failure.backoff_before_attempt_ms == 0 + + +@pytest.mark.asyncio +async def test_batcher_rpcs_instrumented(): + """check that all requests attach proper metadata headers""" + from google.cloud.bigtable.data import TableAsync + from google.cloud.bigtable.data import BigtableDataClientAsync + from google.cloud.bigtable_v2.types import ResponseParams + + cluster_data = "my-cluster" + zone_data = "my-zone" + expected_gfe_latency_ms = 123 + + with mock.patch( + "google.cloud.bigtable_v2.BigtableAsyncClient.mutate_rows" + ) as gapic_mock: + # populate metadata fields + initial_metadata = Metadata( + ( + BIGTABLE_METADATA_KEY, + ResponseParams.serialize( + ResponseParams(zone_id=zone_data, cluster_id=cluster_data) + ), + ) + ) + 
trailing_metadata = Metadata( + (SERVER_TIMING_METADATA_KEY, f"gfet4t7; dur={expected_gfe_latency_ms}") + ) + grpc_call = mock_grpc_call( + initial_metadata=initial_metadata, trailing_metadata=trailing_metadata + ) + gapic_mock.return_value = grpc_call + async with BigtableDataClientAsync() as client: + table = TableAsync(client, "instance-id", "table-id") + # customize metrics handlers + mock_metric_handler = mock.Mock() + table._metrics.handlers = [mock_metric_handler] + async with table.mutations_batcher() as batcher: + await batcher.append( + mutations.RowMutationEntry( + b"row-key", [mutations.DeleteAllFromRow()] + ) + ) + # check for recorded metrics values + assert mock_metric_handler.on_operation_complete.call_count == 1 + found_operation = mock_metric_handler.on_operation_complete.call_args[0][0] + # make sure expected fields were set properly + assert found_operation.op_type == OperationType.BULK_MUTATE_ROWS + now = datetime.datetime.now(datetime.timezone.utc) + assert found_operation.start_time - now < datetime.timedelta(seconds=1) + assert found_operation.duration_ms < 100 + assert found_operation.duration_ms > 0 + assert found_operation.final_status == StatusCode.OK + assert found_operation.cluster_id == cluster_data + assert found_operation.zone == zone_data + assert found_operation.is_streaming is False + # check attempts + assert len(found_operation.completed_attempts) == 1 + found_attempt = found_operation.completed_attempts[0] + assert found_attempt.end_status == StatusCode.OK + assert found_attempt.start_time - now < datetime.timedelta(seconds=1) + assert found_attempt.duration_ms < 100 + assert found_attempt.duration_ms > 0 + assert found_attempt.start_time >= found_operation.start_time + assert found_attempt.duration_ms <= found_operation.duration_ms + assert found_attempt.gfe_latency_ms == expected_gfe_latency_ms + # first response latency not populated, because no real read_rows chunks processed + assert 
found_attempt.first_response_latency_ms is None + # no application blocking time or backoff time expected + assert found_attempt.application_blocking_time_ms == 0 + assert found_attempt.backoff_before_attempt_ms == 0 diff --git a/tests/unit/data/test__helpers.py b/tests/unit/data/test__helpers.py index 5a9c500ed..61299b768 100644 --- a/tests/unit/data/test__helpers.py +++ b/tests/unit/data/test__helpers.py @@ -1,3 +1,4 @@ +# Copyright 2023 Google LLC # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/tests/unit/data/test_read_rows_acceptance.py b/tests/unit/data/test_read_rows_acceptance.py index 7cb3c08dc..7a43581fa 100644 --- a/tests/unit/data/test_read_rows_acceptance.py +++ b/tests/unit/data/test_read_rows_acceptance.py @@ -27,6 +27,7 @@ from google.cloud.bigtable.data.row import Row from ..v2_client.test_row_merger import ReadRowsTest, TestFile +from ._async.test_client import mock_grpc_call def parse_readrows_acceptance_tests(): @@ -60,19 +61,20 @@ def extract_results_from_row(row: Row): ) @pytest.mark.asyncio async def test_row_merger_scenario(test_case: ReadRowsTest): - async def _scenerio_stream(): - for chunk in test_case.chunks: - yield ReadRowsResponse(chunks=[chunk]) + from google.cloud.bigtable.data._metrics.data_model import ActiveOperationMetric try: results = [] instance = mock.Mock() instance._last_yielded_row_key = None instance._remaining_count = None - chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_scenerio_stream()) + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=test_case.chunks)] ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + chunker = _ReadRowsOperationAsync.chunk_stream(instance, stream) + metric = ActiveOperationMetric(0) + metric.start_attempt() + merger = _ReadRowsOperationAsync.merge_rows(chunker, metric) async for row in merger: for cell in row: cell_result = 
ReadRowsTest.Result( @@ -95,38 +97,20 @@ async def _scenerio_stream(): ) @pytest.mark.asyncio async def test_read_rows_scenario(test_case: ReadRowsTest): - async def _make_gapic_stream(chunk_list: list[ReadRowsResponse]): - from google.cloud.bigtable_v2 import ReadRowsResponse - - class mock_stream: - def __init__(self, chunk_list): - self.chunk_list = chunk_list - self.idx = -1 - - def __aiter__(self): - return self - - async def __anext__(self): - self.idx += 1 - if len(self.chunk_list) > self.idx: - chunk = self.chunk_list[self.idx] - return ReadRowsResponse(chunks=[chunk]) - raise StopAsyncIteration - - def cancel(self): - pass - - return mock_stream(chunk_list) - try: with mock.patch.dict(os.environ, {"BIGTABLE_EMULATOR_HOST": "localhost"}): # use emulator mode to avoid auth issues in CI client = BigtableDataClientAsync() table = client.get_table("instance", "table") results = [] - with mock.patch.object(table.client._gapic_client, "read_rows") as read_rows: + with mock.patch.object( + table.client._gapic_client, "read_rows", mock.AsyncMock() + ) as read_rows: # run once, then return error on retry - read_rows.return_value = _make_gapic_stream(test_case.chunks) + stream = mock_grpc_call( + stream_response=[ReadRowsResponse(chunks=[c]) for c in test_case.chunks] + ) + read_rows.return_value = stream async for row in await table.read_rows_stream(query={}): for cell in row: cell_result = ReadRowsTest.Result( @@ -148,16 +132,14 @@ def cancel(self): @pytest.mark.asyncio async def test_out_of_order_rows(): - async def _row_stream(): - yield ReadRowsResponse(last_scanned_row_key=b"a") - instance = mock.Mock() instance._remaining_count = None instance._last_yielded_row_key = b"b" chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) + instance, + mock_grpc_call(stream_response=[ReadRowsResponse(last_scanned_row_key=b"a")]), ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + merger = _ReadRowsOperationAsync.merge_rows(chunker, 
mock.Mock()) with pytest.raises(InvalidChunk): async for _ in merger: pass @@ -310,21 +292,14 @@ async def test_mid_cell_labels_change(): ) -async def _coro_wrapper(stream): - return stream - - async def _process_chunks(*chunks): - async def _row_stream(): - yield ReadRowsResponse(chunks=chunks) - instance = mock.Mock() instance._remaining_count = None instance._last_yielded_row_key = None chunker = _ReadRowsOperationAsync.chunk_stream( - instance, _coro_wrapper(_row_stream()) + instance, mock_grpc_call(stream_response=[ReadRowsResponse(chunks=chunks)]) ) - merger = _ReadRowsOperationAsync.merge_rows(chunker) + merger = _ReadRowsOperationAsync.merge_rows(chunker, mock.Mock()) results = [] async for row in merger: results.append(row)