diff --git a/bigtable/google/cloud/bigtable/retry.py b/bigtable/google/cloud/bigtable/retry.py
new file mode 100644
index 000000000000..2486f8b20fa9
--- /dev/null
+++ b/bigtable/google/cloud/bigtable/retry.py
@@ -0,0 +1,169 @@
+"""Provides function wrappers that implement retrying."""
+from __future__ import division
+
+import random
+import time
+
+import six
+
+from google.cloud._helpers import _to_bytes
+from google.cloud.bigtable._generated import (
+    bigtable_pb2 as data_messages_v2_pb2)
+from google.gax import config, errors
+
+_MILLIS_PER_SECOND = 1000
+
+
+def _has_timeout_settings(backoff_settings):
+    return (backoff_settings.rpc_timeout_multiplier is not None and
+            backoff_settings.max_rpc_timeout_millis is not None and
+            backoff_settings.total_timeout_millis is not None and
+            backoff_settings.initial_rpc_timeout_millis is not None)
+
+
+class ReadRowsIterator(object):
+    """Iterator over ReadRows responses that retries on transient errors.
+
+    On a retryable failure the stream is re-opened with a row range that
+    resumes after the last row key already received.
+    """
+
+    def __init__(self, client, name, start_key, end_key, filter_, limit,
+                 retry_options, **kwargs):
+        self.client = client
+        self.retry_options = retry_options
+        self.name = name
+        self.start_key = start_key
+        self.start_key_closed = True
+        self.end_key = end_key
+        self.filter_ = filter_
+        self.limit = limit
+
+        backoff_settings = retry_options.backoff_settings
+        self.delay_mult = backoff_settings.retry_delay_multiplier
+        self.max_delay_millis = backoff_settings.max_retry_delay_millis
+        self.has_timeout_settings = _has_timeout_settings(backoff_settings)
+
+        if self.has_timeout_settings:
+            self.timeout_mult = backoff_settings.rpc_timeout_multiplier
+            self.max_timeout = (
+                backoff_settings.max_rpc_timeout_millis / _MILLIS_PER_SECOND)
+            self.total_timeout = (
+                backoff_settings.total_timeout_millis / _MILLIS_PER_SECOND)
+        self.set_stream()
+
+    def set_start_key(self, start_key):
+        """Sets the row key from which a retried scan will resume.
+
+        The resumed range is open on the start side, so ``start_key``
+        itself is not re-read.
+        """
+        self.start_key = start_key
+        self.start_key_closed = False
+
+    def set_stream(self):
+        request_pb = _create_row_request(
+            self.name, start_key=self.start_key,
+            start_key_closed=self.start_key_closed, end_key=self.end_key,
+            filter_=self.filter_, limit=self.limit)
+        self.stream = self.client._data_stub.ReadRows(request_pb)
+
+    def next(self, *args, **kwargs):
+        delay = self.retry_options.backoff_settings.initial_retry_delay_millis
+        exc = errors.RetryError('Retry total timeout exceeded before any '
+                                'response was received')
+        if self.has_timeout_settings:
+            timeout = (
+                self.retry_options.backoff_settings.initial_rpc_timeout_millis
+                / _MILLIS_PER_SECOND)
+
+            now = time.time()
+            deadline = now + self.total_timeout
+        else:
+            timeout = None
+            deadline = None
+
+        while deadline is None or now < deadline:
+            try:
+                return six.next(self.stream)
+            except StopIteration:
+                raise
+            except Exception as exception:  # pylint: disable=broad-except
+                code = config.exc_to_code(exception)
+                if code not in self.retry_options.retry_codes:
+                    raise errors.RetryError(
+                        'Exception occurred in retry method that was not'
+                        ' classified as transient', exception)
+
+                # pylint: disable=redefined-variable-type
+                exc = errors.RetryError(
+                    'Retry total timeout exceeded with exception', exception)
+
+                # Sleep a random number which will, on average, equal the
+                # expected delay.
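+                # (A uniform draw over [0, 2 * delay] has mean ``delay``, so
+                # the average backoff still grows geometrically, e.g. 10ms,
+                # 20ms, 40ms, up to max_retry_delay_millis under the default
+                # settings in table.py, while the jitter keeps concurrent
+                # clients from retrying in lock step.)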
+                to_sleep = random.uniform(0, delay * 2)
+                time.sleep(to_sleep / _MILLIS_PER_SECOND)
+                delay = min(delay * self.delay_mult, self.max_delay_millis)
+
+                if self.has_timeout_settings:
+                    now = time.time()
+                    timeout = min(
+                        timeout * self.timeout_mult, self.max_timeout,
+                        deadline - now)
+                self.set_stream()
+
+        raise exc
+
+    def __next__(self, *args, **kwargs):
+        return self.next(*args, **kwargs)
+
+    def __iter__(self):
+        return self
+
+
+def _create_row_request(table_name, row_key=None, start_key=None,
+                        start_key_closed=True, end_key=None, filter_=None,
+                        limit=None):
+    """Creates a request to read rows in a table.
+
+    :type table_name: str
+    :param table_name: The name of the table to read from.
+
+    :type row_key: bytes
+    :param row_key: (Optional) The key of a specific row to read from.
+
+    :type start_key: bytes
+    :param start_key: (Optional) The beginning of a range of row keys to
+                      read from. The range will include ``start_key`` only
+                      if ``start_key_closed`` is true. If left empty, will
+                      be interpreted as the empty string.
+
+    :type start_key_closed: bool
+    :param start_key_closed: (Optional) Whether the range includes
+                             ``start_key``. Retried scans pass ``False`` so
+                             the last row already read is not re-read.
+
+    :type end_key: bytes
+    :param end_key: (Optional) The end of a range of row keys to read from.
+                    The range will not include ``end_key``. If left empty,
+                    will be interpreted as an infinite string.
+
+    :type filter_: :class:`.RowFilter`
+    :param filter_: (Optional) The filter to apply to the contents of the
+                    specified row(s). If unset, reads the entire table.
+
+    :type limit: int
+    :param limit: (Optional) The read will terminate after committing to N
+                  rows' worth of results. The default (zero) is to return
+                  all results.
+
+    :rtype: :class:`data_messages_v2_pb2.ReadRowsRequest`
+    :returns: The ``ReadRowsRequest`` protobuf corresponding to the inputs.
+    :raises: :class:`ValueError <exceptions.ValueError>` if both
+             ``row_key`` and one of ``start_key`` and ``end_key`` are set
+    """
+    request_kwargs = {'table_name': table_name}
+    if (row_key is not None and
+            (start_key is not None or end_key is not None)):
+        raise ValueError('Row key and row range cannot be '
+                         'set simultaneously')
+    range_kwargs = {}
+    if start_key is not None or end_key is not None:
+        if start_key is not None:
+            if start_key_closed:
+                range_kwargs['start_key_closed'] = _to_bytes(start_key)
+            else:
+                range_kwargs['start_key_open'] = _to_bytes(start_key)
+        if end_key is not None:
+            range_kwargs['end_key_open'] = _to_bytes(end_key)
+    if filter_ is not None:
+        request_kwargs['filter'] = filter_.to_pb()
+    if limit is not None:
+        request_kwargs['rows_limit'] = limit
+
+    message = data_messages_v2_pb2.ReadRowsRequest(**request_kwargs)
+
+    if row_key is not None:
+        message.rows.row_keys.append(_to_bytes(row_key))
+
+    if range_kwargs:
+        message.rows.row_ranges.add(**range_kwargs)
+
+    return message
diff --git a/bigtable/google/cloud/bigtable/row_data.py b/bigtable/google/cloud/bigtable/row_data.py
index 60fc1f0ef1e8..d3c70d431e29 100644
--- a/bigtable/google/cloud/bigtable/row_data.py
+++ b/bigtable/google/cloud/bigtable/row_data.py
@@ -274,6 +274,9 @@ def consume_next(self):
 
         self._validate_chunk(chunk)
 
+        if hasattr(self._response_iterator, 'set_start_key'):
+            self._response_iterator.set_start_key(chunk.row_key)
+
         if chunk.reset_row:
             row = self._row = None
             cell = self._cell = self._previous_cell = None
diff --git a/bigtable/google/cloud/bigtable/table.py b/bigtable/google/cloud/bigtable/table.py
index 3fbd198d6b65..085cb4b2ba9f 100644
--- a/bigtable/google/cloud/bigtable/table.py
+++ b/bigtable/google/cloud/bigtable/table.py
@@ -13,8 +13,9 @@
 # limitations under the License.
 
"""User friendly container for Google Cloud Bigtable Table.""" +from __future__ import absolute_import, division + +import six -from google.cloud._helpers import _to_bytes from google.cloud.bigtable._generated import ( bigtable_pb2 as data_messages_v2_pb2) from google.cloud.bigtable._generated import ( @@ -27,7 +29,29 @@ from google.cloud.bigtable.row import ConditionalRow from google.cloud.bigtable.row import DirectRow from google.cloud.bigtable.row_data import PartialRowsData - +from google.gax import RetryOptions, BackoffSettings +from google.cloud.bigtable.retry import ReadRowsIterator, _create_row_request +from grpc import StatusCode + +BACKOFF_SETTINGS = BackoffSettings( + initial_retry_delay_millis = 10, + retry_delay_multiplier = 2, + max_retry_delay_millis = 5000, + initial_rpc_timeout_millis = 10, + rpc_timeout_multiplier = 2, + max_rpc_timeout_millis = 1000, + total_timeout_millis = 5000 +) + +RETRY_OPTIONS = RetryOptions( + retry_codes = [ + StatusCode.DEADLINE_EXCEEDED, + StatusCode.ABORTED, + StatusCode.INTERNAL, + StatusCode.UNAVAILABLE + ], + backoff_settings = BACKOFF_SETTINGS +) class Table(object): """Representation of a Google Cloud Bigtable Table. @@ -268,13 +292,11 @@ def read_rows(self, start_key=None, end_key=None, limit=None, :returns: A :class:`.PartialRowsData` convenience wrapper for consuming the streamed results. """ - request_pb = _create_row_request( - self.name, start_key=start_key, end_key=end_key, filter_=filter_, - limit=limit) + client = self._instance._client - response_iterator = client._data_stub.ReadRows(request_pb) - # We expect an iterator of `data_messages_v2_pb2.ReadRowsResponse` - return PartialRowsData(response_iterator) + retrying_iterator = ReadRowsIterator(client, self.name, start_key, + end_key, filter_, limit, RETRY_OPTIONS) + return PartialRowsData(retrying_iterator) def sample_row_keys(self): """Read a sample of row keys in the table. @@ -312,64 +334,3 @@ def sample_row_keys(self): client = self._instance._client response_iterator = client._data_stub.SampleRowKeys(request_pb) return response_iterator - - -def _create_row_request(table_name, row_key=None, start_key=None, end_key=None, - filter_=None, limit=None): - """Creates a request to read rows in a table. - - :type table_name: str - :param table_name: The name of the table to read from. - - :type row_key: bytes - :param row_key: (Optional) The key of a specific row to read from. - - :type start_key: bytes - :param start_key: (Optional) The beginning of a range of row keys to - read from. The range will include ``start_key``. If - left empty, will be interpreted as the empty string. - - :type end_key: bytes - :param end_key: (Optional) The end of a range of row keys to read from. - The range will not include ``end_key``. If left empty, - will be interpreted as an infinite string. - - :type filter_: :class:`.RowFilter` - :param filter_: (Optional) The filter to apply to the contents of the - specified row(s). If unset, reads the entire table. - - :type limit: int - :param limit: (Optional) The read will terminate after committing to N - rows' worth of results. The default (zero) is to return - all results. - - :rtype: :class:`data_messages_v2_pb2.ReadRowsRequest` - :returns: The ``ReadRowsRequest`` protobuf corresponding to the inputs. 
-    :raises: :class:`ValueError <exceptions.ValueError>` if both
-             ``row_key`` and one of ``start_key`` and ``end_key`` are set
-    """
-    request_kwargs = {'table_name': table_name}
-    if (row_key is not None and
-            (start_key is not None or end_key is not None)):
-        raise ValueError('Row key and row range cannot be '
-                         'set simultaneously')
-    range_kwargs = {}
-    if start_key is not None or end_key is not None:
-        if start_key is not None:
-            range_kwargs['start_key_closed'] = _to_bytes(start_key)
-        if end_key is not None:
-            range_kwargs['end_key_open'] = _to_bytes(end_key)
-    if filter_ is not None:
-        request_kwargs['filter'] = filter_.to_pb()
-    if limit is not None:
-        request_kwargs['rows_limit'] = limit
-
-    message = data_messages_v2_pb2.ReadRowsRequest(**request_kwargs)
-
-    if row_key is not None:
-        message.rows.row_keys.append(_to_bytes(row_key))
-
-    if range_kwargs:
-        message.rows.row_ranges.add(**range_kwargs)
-
-    return message
diff --git a/bigtable/unit_tests/retries b/bigtable/unit_tests/retries
new file mode 100755
index 000000000000..0f333227e934
Binary files /dev/null and b/bigtable/unit_tests/retries differ
diff --git a/bigtable/unit_tests/retry b/bigtable/unit_tests/retry
new file mode 100755
index 000000000000..1b72d2e4b0ed
Binary files /dev/null and b/bigtable/unit_tests/retry differ
diff --git a/bigtable/unit_tests/retry_server.go b/bigtable/unit_tests/retry_server.go
new file mode 100755
index 000000000000..ec5c9a902691
--- /dev/null
+++ b/bigtable/unit_tests/retry_server.go
@@ -0,0 +1,295 @@
+package main
+
+import (
+	"bufio"
+	"flag"
+	"fmt"
+	"log"
+	"os"
+	"reflect"
+	"strings"
+
+	"cloud.google.com/go/bigtable"
+	"cloud.google.com/go/bigtable/bttest"
+	"github.com/golang/protobuf/ptypes/wrappers"
+	"golang.org/x/net/context"
+	"google.golang.org/api/option"
+	btpb "google.golang.org/genproto/googleapis/bigtable/v2"
+	rpcpb "google.golang.org/genproto/googleapis/rpc/status"
+	"google.golang.org/grpc"
+	"google.golang.org/grpc/codes"
+)
+
+var (
+	scriptFile = flag.String("script", "", "the file containing the script")
+	codeMap    = make(map[string]codes.Code)
+	failed     = false
+)
+
+type serverScript struct {
+	actions []string
+	idx     int
+}
+
+func init() {
+	allCodes := []codes.Code{
+		codes.OK, codes.Canceled, codes.Unknown, codes.InvalidArgument,
+		codes.DeadlineExceeded, codes.NotFound, codes.AlreadyExists,
+		codes.PermissionDenied, codes.Unauthenticated, codes.ResourceExhausted,
+		codes.FailedPrecondition, codes.Aborted, codes.OutOfRange,
+		codes.Unimplemented, codes.Internal, codes.Unavailable, codes.DataLoss}
+	for _, code := range allCodes {
+		codeMap[code.String()] = code
+	}
+}
+
+func (s *serverScript) serverAction() []string {
+	return s.nextAction("SERVER:")
+}
+
+func (s *serverScript) expectAction() []string {
+	return s.nextAction("EXPECT:")
+}
+
+func (s *serverScript) isFinished() bool {
+	return s.idx == len(s.actions)
+}
+
+func (s *serverScript) nextAction(prefix string) []string {
+	if s.isFinished() {
+		return nil
+	}
+
+	a := s.actions[s.idx]
+	if strings.HasPrefix(a, prefix) {
+		s.idx++
+		return strings.Split(a, " ")[1:]
+	}
+	return nil
+}
+
+func main() {
+	flag.Parse()
+
+	var actions []string
+	if file, err := os.Open(*scriptFile); err == nil {
+		defer file.Close()
+
+		scanner := bufio.NewScanner(file)
+		for scanner.Scan() {
+			line := scanner.Text()
+			if len(strings.TrimSpace(line)) == 0 ||
+				strings.HasPrefix(line, "#") ||
+				strings.HasPrefix(line, "CLIENT:") {
+				// Skip blanks, comments and CLIENT: lines; the latter are
+				// executed by the client under test, not by this server.
+				continue
+			}
+			actions = append(actions,
+				strings.TrimSpace(line))
+		}
+
+		// check for errors
+		if err = scanner.Err(); err != nil {
+			log.Fatal(err)
+		}
+	} else {
+		log.Fatal(err)
+	}
+
+	ctx := context.Background()
+	script := serverScript{actions: actions}
+
+	// Create the interceptor that will do all of our work.
+	interceptor := func(srv interface{}, ss grpc.ServerStream, info *grpc.StreamServerInfo, handler grpc.StreamHandler) error {
+		if failed {
+			return grpc.Errorf(codes.Canceled, "The test has failed")
+		}
+		if strings.HasSuffix(info.FullMethod, "MutateRows") || strings.HasSuffix(info.FullMethod, "ReadRows") {
+			action := script.expectAction()
+			op := action[0]
+			fmt.Printf("Expect: %s\n", op)
+			switch op {
+			case "SCAN":
+				if !strings.HasSuffix(info.FullMethod, "ReadRows") {
+					fail("Expected %v, received call to %v", action, info.FullMethod)
+				}
+				validateScan(action, ss)
+			case "READ":
+				if !strings.HasSuffix(info.FullMethod, "ReadRows") {
+					fail("Expected %v, received call to %v", action, info.FullMethod)
+				}
+				validateRead(strings.Split(action[1], ","), ss)
+			case "WRITE":
+				if !strings.HasSuffix(info.FullMethod, "MutateRows") {
+					fail("Expected %v, received call to %v", action, info.FullMethod)
+				}
+				validateWrite(strings.Split(action[1], ","), ss)
+			}
+
+			for {
+				action = script.serverAction()
+				if action == nil {
+					break
+				}
+				log.Printf("Action: %s\n", action)
+				switch action[0] {
+				case "READ_RESPONSE":
+					writeReadRowsResponse(ss, strings.Split(action[1], ",")...)
+				case "WRITE_RESPONSE":
+					writeMutateRowsResponse(ss, strings.Split(action[1], ",")...)
+				case "ERROR":
+					return grpc.Errorf(codeMap[action[1]], "")
+				default:
+					fail("Invalid action during response: %v", action)
+				}
+			}
+			if script.isFinished() && !failed {
+				fmt.Println("PASS")
+			}
+			return nil
+		}
+		// Delegate to the handler for other operations (but there shouldn't be any)
+		return handler(ctx, ss)
+	}
+
+	cleaner, err := setupFakeServer(grpc.StreamInterceptor(interceptor))
+	if err != nil {
+		fmt.Println("FAIL")
+		log.Fatal(err)
+	}
+	defer cleaner()
+	select {}
+}
+
+func validateScan(action []string, ss grpc.ServerStream) {
+	// Look for one or more ranges
+	var wantRanges []*btpb.RowRange
+	for _, r := range action[1:] {
+		startEnd := strings.Split(r[1:len(r)-1], ",")
+		rr := btpb.RowRange{}
+		if strings.HasPrefix(r, "[") {
+			rr.StartKey = &btpb.RowRange_StartKeyClosed{StartKeyClosed: []byte(startEnd[0])}
+		} else if strings.HasPrefix(r, "(") {
+			rr.StartKey = &btpb.RowRange_StartKeyOpen{StartKeyOpen: []byte(startEnd[0])}
+		} else {
+			fail("Invalid range: %v", r)
+		}
+
+		if strings.HasSuffix(r, "]") {
+			rr.EndKey = &btpb.RowRange_EndKeyClosed{EndKeyClosed: []byte(startEnd[1])}
+		} else if strings.HasSuffix(r, ")") {
+			rr.EndKey = &btpb.RowRange_EndKeyOpen{EndKeyOpen: []byte(startEnd[1])}
+		} else {
+			fail("Invalid range: %v", r)
+		}
+		wantRanges = append(wantRanges, &rr)
+	}
+
+	req := new(btpb.ReadRowsRequest)
+	ss.RecvMsg(req)
+
+	want := &btpb.RowSet{RowRanges: wantRanges}
+	if !reflect.DeepEqual(want, req.Rows) {
+		fail("Invalid scan. got: %v\nwant: %v\n", req.Rows, want)
+	}
+}
+
+func validateWrite(keys []string, ss grpc.ServerStream) {
+	want := make([][]byte, len(keys))
+	for i, row := range keys {
+		want[i] = []byte(row)
+	}
+
+	req := new(btpb.MutateRowsRequest)
+	ss.RecvMsg(req)
+
+	var got [][]byte
+	for _, entry := range req.Entries {
+		got = append(got, entry.RowKey)
+	}
+
+	if !reflect.DeepEqual(got, want) {
+		fail("Invalid write. got: %v\nwant: %v\n", got, want)
+	}
+}
+
+func validateRead(keys []string, ss grpc.ServerStream) {
+	keyBytes := make([][]byte, len(keys))
+	for i, row := range keys {
+		keyBytes[i] = []byte(row)
+	}
+	want := &btpb.RowSet{RowKeys: keyBytes}
+
+	req := new(btpb.ReadRowsRequest)
+	ss.RecvMsg(req)
+
+	if !reflect.DeepEqual(want, req.Rows) {
+		fail("Invalid read. got: %v\nwant: %v\n", req.Rows, want)
+	}
+}
+
+func writeReadRowsResponse(ss grpc.ServerStream, rowKeys ...string) error {
+	var chunks []*btpb.ReadRowsResponse_CellChunk
+	for _, key := range rowKeys {
+		chunks = append(chunks, &btpb.ReadRowsResponse_CellChunk{
+			RowKey:     []byte(key),
+			FamilyName: &wrappers.StringValue{Value: "fm"},
+			Qualifier:  &wrappers.BytesValue{Value: []byte("col")},
+			RowStatus:  &btpb.ReadRowsResponse_CellChunk_CommitRow{CommitRow: true},
+		})
+	}
+	return ss.SendMsg(&btpb.ReadRowsResponse{Chunks: chunks})
+}
+
+func writeMutateRowsResponse(ss grpc.ServerStream, codeNames ...string) error {
+	res := &btpb.MutateRowsResponse{Entries: make([]*btpb.MutateRowsResponse_Entry, len(codeNames))}
+	for i, code := range codeNames {
+		res.Entries[i] = &btpb.MutateRowsResponse_Entry{
+			Index:  int64(i),
+			Status: &rpcpb.Status{Code: int32(codeMap[code]), Message: ""},
+		}
+	}
+	return ss.SendMsg(res)
+}
+
+// fail records a failure and reports it on stdout; the server keeps running
+// so the interceptor can reject any further RPCs with codes.Canceled.
+func fail(format string, v ...interface{}) {
+	log.Printf(format, v...)
+	fmt.Println("FAIL")
+	failed = true
+}
+
+func setupFakeServer(opt ...grpc.ServerOption) (cleanup func(), err error) {
+	srv, err := bttest.NewServer("127.0.0.1:", opt...)
+	if err != nil {
+		return nil, err
+	}
+	conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
+	if err != nil {
+		return nil, err
+	}
+
+	client, err := bigtable.NewClient(context.Background(), "client", "instance", option.WithGRPCConn(conn))
+	if err != nil {
+		return nil, err
+	}
+
+	adminClient, err := bigtable.NewAdminClient(context.Background(), "client", "instance", option.WithGRPCConn(conn))
+	if err != nil {
+		return nil, err
+	}
+	if err := adminClient.CreateTable(context.Background(), "table"); err != nil {
+		return nil, err
+	}
+	if err := adminClient.CreateColumnFamily(context.Background(), "table", "cf"); err != nil {
+		return nil, err
+	}
+
+	fmt.Println(srv.Addr)
+
+	cleanupFunc := func() {
+		adminClient.Close()
+		client.Close()
+		srv.Close()
+	}
+	return cleanupFunc, nil
+}
diff --git a/bigtable/unit_tests/retry_test.txt b/bigtable/unit_tests/retry_test.txt
new file mode 100644
index 000000000000..863662e897ba
--- /dev/null
+++ b/bigtable/unit_tests/retry_test.txt
@@ -0,0 +1,38 @@
+# This retry script is processed by the retry server and the client under test.
+# Client tests should parse any command beginning with "CLIENT:", send the
+# corresponding RPC to the retry server and expect a valid response.
+# "EXPECT:" commands indicate the call the server is expecting the client to send.
+#
+# The retry server has one table named "table" that should be used for testing.
+# There are three types of commands supported:
+#   READ <row ids>
+#     Expect the corresponding rows to be returned with arbitrary values.
+#   SCAN <range> <row ids>
+#     Ranges are expressed as an interval with either open or closed start and
+#     end, such as [1,3) for "1,2" or (1,3] for "2,3".
+#   WRITE <row ids>
+#     All writes should succeed eventually. Value payload is ignored.
+# The server writes PASS or FAIL on a line by itself to STDOUT depending on the
+# result of the test. All other server output should be ignored.
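+#
+# For example: after a transient error interrupts
+#   CLIENT: SCAN [r1,r9) ...
+# with rows up to r4 delivered, the client is expected to resume with the
+# open-start range (r4,r9) so that r4 is not returned a second time.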
+
+# Echo same scan back after immediate error
+CLIENT: SCAN [r1,r3) r1,r2
+EXPECT: SCAN [r1,r3)
+SERVER: ERROR Unavailable
+EXPECT: SCAN [r1,r3)
+SERVER: READ_RESPONSE r1,r2
+
+# Retry scans with open interval starting at the last read row key.
+# Instead of using open intervals for retry ranges, '\x00' can be
+# appended to the last received row key and sent in a closed interval.
+CLIENT: SCAN [r1,r9) r1,r2,r3,r4,r5,r6,r7,r8
+EXPECT: SCAN [r1,r9)
+SERVER: READ_RESPONSE r1,r2,r3,r4
+SERVER: ERROR Unavailable
+EXPECT: SCAN (r4,r9)
+SERVER: ERROR Unavailable
+EXPECT: SCAN (r4,r9)
+SERVER: READ_RESPONSE r5,r6,r7
+SERVER: ERROR Unavailable
+EXPECT: SCAN (r7,r9)
+SERVER: READ_RESPONSE r8
diff --git a/bigtable/unit_tests/test_retry.py b/bigtable/unit_tests/test_retry.py
new file mode 100644
index 000000000000..e79e460a9b43
--- /dev/null
+++ b/bigtable/unit_tests/test_retry.py
@@ -0,0 +1,54 @@
+import os
+import subprocess
+import unittest
+
+from google.cloud.bigtable.client import Client
+from google.cloud.bigtable.instance import Instance
+
+
+class TestRetry(unittest.TestCase):
+
+    TEST_SCRIPT = "unit_tests/retry_test.txt"
+
+    def test_retry(self):
+        table, server = self.connect_to_server()
+        with open(self.TEST_SCRIPT, 'r') as script:
+            for line in script.readlines():
+                if line.startswith("CLIENT:"):
+                    self.process_line(table, line)
+        server.kill()
+
+    def process_line(self, table, line):
+        chunks = line.split(" ")
+        op = chunks[1]
+        if op == "READ":
+            self.process_read(table, chunks[2])
+        elif op == "WRITE":
+            self.process_write(table, chunks[2])
+        elif op == "SCAN":
+            self.process_scan(table, chunks[2], chunks[3])
+
+    def process_read(self, table, payload):
+        pass
+
+    def process_write(self, table, payload):
+        pass
+
+    def process_scan(self, table, row_range, ids):
+        # A CLIENT: range is always closed-open, e.g. "[r1,r9)".
+        range_chunks = row_range.split(",")
+        range_open = range_chunks[0].lstrip("[")
+        range_close = range_chunks[1].rstrip(")")
+        rows = table.read_rows(range_open, range_close)
+        rows.consume_all()
+
+    def connect_to_server(self):
+        # 'retry' is the checked-in binary built from retry_server.go; it
+        # prints the fake server's "host:port" on its first line of output.
+        server = subprocess.Popen(
+            ['./unit_tests/retry', '--script=' + self.TEST_SCRIPT],
+            stdin=subprocess.PIPE, stdout=subprocess.PIPE,
+        )
+        address = server.stdout.readline().decode('utf-8').rstrip("\n")
+        (endpoint, port) = address.split(":")
+        os.environ["BIGTABLE_EMULATOR_HOST"] = endpoint + ":" + port
+        client = Client(project="client", admin=True)
+        instance = Instance("instance", client)
+        table = instance.table("table")
+        return (table, server)
diff --git a/bigtable/unit_tests/test_table.py b/bigtable/unit_tests/test_table.py
deleted file mode 100644
index 4ad6afe1596f..000000000000
--- a/bigtable/unit_tests/test_table.py
+++ /dev/null
@@ -1,599 +0,0 @@
-# Copyright 2015 Google Inc.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-
-import unittest
-
-
-class TestTable(unittest.TestCase):
-
-    PROJECT_ID = 'project-id'
-    INSTANCE_ID = 'instance-id'
-    INSTANCE_NAME = ('projects/' + PROJECT_ID + '/instances/' + INSTANCE_ID)
-    TABLE_ID = 'table-id'
-    TABLE_NAME = INSTANCE_NAME + '/tables/' + TABLE_ID
-    ROW_KEY = b'row-key'
-    FAMILY_NAME = u'family'
-    QUALIFIER = b'qualifier'
-    TIMESTAMP_MICROS = 100
-    VALUE = b'value'
-
-    @staticmethod
-    def _get_target_class():
-        from google.cloud.bigtable.table import Table
-
-        return Table
-
-    def _make_one(self, *args, **kwargs):
-        return self._get_target_class()(*args, **kwargs)
-
-    def test_constructor(self):
-        table_id = 'table-id'
-        instance = object()
-
-        table = self._make_one(table_id, instance)
-        self.assertEqual(table.table_id, table_id)
-        self.assertIs(table._instance, instance)
-
-    def test_name_property(self):
-        table_id = 'table-id'
-        instance_name = 'instance_name'
-
-        instance = _Instance(instance_name)
-        table = self._make_one(table_id, instance)
-        expected_name = instance_name + '/tables/' + table_id
-        self.assertEqual(table.name, expected_name)
-
-    def test_column_family_factory(self):
-        from google.cloud.bigtable.column_family import ColumnFamily
-
-        table_id = 'table-id'
-        gc_rule = object()
-        table = self._make_one(table_id, None)
-        column_family_id = 'column_family_id'
-        column_family = table.column_family(column_family_id, gc_rule=gc_rule)
-
-        self.assertIsInstance(column_family, ColumnFamily)
-        self.assertEqual(column_family.column_family_id, column_family_id)
-        self.assertIs(column_family.gc_rule, gc_rule)
-        self.assertEqual(column_family._table, table)
-
-    def test_row_factory_direct(self):
-        from google.cloud.bigtable.row import DirectRow
-
-        table_id = 'table-id'
-        table = self._make_one(table_id, None)
-        row_key = b'row_key'
-        row = table.row(row_key)
-
-        self.assertIsInstance(row, DirectRow)
-        self.assertEqual(row._row_key, row_key)
-        self.assertEqual(row._table, table)
-
-    def test_row_factory_conditional(self):
-        from google.cloud.bigtable.row import ConditionalRow
-
-        table_id = 'table-id'
-        table = self._make_one(table_id, None)
-        row_key = b'row_key'
-        filter_ = object()
-        row = table.row(row_key, filter_=filter_)
-
-        self.assertIsInstance(row, ConditionalRow)
-        self.assertEqual(row._row_key, row_key)
-        self.assertEqual(row._table, table)
-
-    def test_row_factory_append(self):
-        from google.cloud.bigtable.row import AppendRow
-
-        table_id = 'table-id'
-        table = self._make_one(table_id, None)
-        row_key = b'row_key'
-        row = table.row(row_key, append=True)
-
-        self.assertIsInstance(row, AppendRow)
-        self.assertEqual(row._row_key, row_key)
-        self.assertEqual(row._table, table)
-
-    def test_row_factory_failure(self):
-        table = self._make_one(self.TABLE_ID, None)
-        with self.assertRaises(ValueError):
-            table.row(b'row_key', filter_=object(), append=True)
-
-    def test___eq__(self):
-        instance = object()
-        table1 = self._make_one(self.TABLE_ID, instance)
-        table2 = self._make_one(self.TABLE_ID, instance)
-        self.assertEqual(table1, table2)
-
-    def test___eq__type_differ(self):
-        table1 = self._make_one(self.TABLE_ID, None)
-        table2 = object()
-        self.assertNotEqual(table1, table2)
-
-    def test___ne__same_value(self):
-        instance = object()
-        table1 = self._make_one(self.TABLE_ID, instance)
-        table2 = self._make_one(self.TABLE_ID, instance)
-        comparison_val = (table1 != table2)
-        self.assertFalse(comparison_val)
-
-    def test___ne__(self):
-        table1 = self._make_one('table_id1', 'instance1')
-        table2 = self._make_one('table_id2', 'instance2')
-        self.assertNotEqual(table1, table2)
-
-    def _create_test_helper(self, initial_split_keys, column_families=()):
-        from google.cloud._helpers import _to_bytes
-        from unit_tests._testing import _FakeStub
-
-        client = _Client()
-        instance = _Instance(self.INSTANCE_NAME, client=client)
-        table = self._make_one(self.TABLE_ID, instance)
-
-        # Create request_pb
-        splits_pb = [
-            _CreateTableRequestSplitPB(key=_to_bytes(key))
-            for key in initial_split_keys or ()]
-        table_pb = None
-        if column_families:
-            table_pb = _TablePB()
-            for cf in column_families:
-                cf_pb = table_pb.column_families[cf.column_family_id]
-                if cf.gc_rule is not None:
-                    cf_pb.gc_rule.MergeFrom(cf.gc_rule.to_pb())
-        request_pb = _CreateTableRequestPB(
-            initial_splits=splits_pb,
-            parent=self.INSTANCE_NAME,
-            table_id=self.TABLE_ID,
-            table=table_pb,
-        )
-
-        # Create response_pb
-        response_pb = _TablePB()
-
-        # Patch the stub used by the API method.
-        client._table_stub = stub = _FakeStub(response_pb)
-
-        # Create expected_result.
-        expected_result = None  # create() has no return value.
-
-        # Perform the method and check the result.
-        result = table.create(initial_split_keys=initial_split_keys,
-                              column_families=column_families)
-        self.assertEqual(result, expected_result)
-        self.assertEqual(stub.method_calls, [(
-            'CreateTable',
-            (request_pb,),
-            {},
-        )])
-
-    def test_create(self):
-        initial_split_keys = None
-        self._create_test_helper(initial_split_keys)
-
-    def test_create_with_split_keys(self):
-        initial_split_keys = [b's1', b's2']
-        self._create_test_helper(initial_split_keys)
-
-    def test_create_with_column_families(self):
-        from google.cloud.bigtable.column_family import ColumnFamily
-        from google.cloud.bigtable.column_family import MaxVersionsGCRule
-
-        cf_id1 = 'col-fam-id1'
-        cf1 = ColumnFamily(cf_id1, None)
-        cf_id2 = 'col-fam-id2'
-        gc_rule = MaxVersionsGCRule(42)
-        cf2 = ColumnFamily(cf_id2, None, gc_rule=gc_rule)
-
-        initial_split_keys = None
-        column_families = [cf1, cf2]
-        self._create_test_helper(initial_split_keys,
-                                 column_families=column_families)
-
-    def _list_column_families_helper(self):
-        from unit_tests._testing import _FakeStub
-
-        client = _Client()
-        instance = _Instance(self.INSTANCE_NAME, client=client)
-        table = self._make_one(self.TABLE_ID, instance)
-
-        # Create request_pb
-        request_pb = _GetTableRequestPB(name=self.TABLE_NAME)
-
-        # Create response_pb
-        COLUMN_FAMILY_ID = 'foo'
-        column_family = _ColumnFamilyPB()
-        response_pb = _TablePB(
-            column_families={COLUMN_FAMILY_ID: column_family},
-        )
-
-        # Patch the stub used by the API method.
-        client._table_stub = stub = _FakeStub(response_pb)
-
-        # Create expected_result.
-        expected_result = {
-            COLUMN_FAMILY_ID: table.column_family(COLUMN_FAMILY_ID),
-        }
-
-        # Perform the method and check the result.
-        result = table.list_column_families()
-        self.assertEqual(result, expected_result)
-        self.assertEqual(stub.method_calls, [(
-            'GetTable',
-            (request_pb,),
-            {},
-        )])
-
-    def test_list_column_families(self):
-        self._list_column_families_helper()
-
-    def test_delete(self):
-        from google.protobuf import empty_pb2
-        from unit_tests._testing import _FakeStub
-
-        client = _Client()
-        instance = _Instance(self.INSTANCE_NAME, client=client)
-        table = self._make_one(self.TABLE_ID, instance)
-
-        # Create request_pb
-        request_pb = _DeleteTableRequestPB(name=self.TABLE_NAME)
-
-        # Create response_pb
-        response_pb = empty_pb2.Empty()
-
-        # Patch the stub used by the API method.
-        client._table_stub = stub = _FakeStub(response_pb)
-
-        # Create expected_result.
-        expected_result = None  # delete() has no return value.
-
-        # Perform the method and check the result.
-        result = table.delete()
-        self.assertEqual(result, expected_result)
-        self.assertEqual(stub.method_calls, [(
-            'DeleteTable',
-            (request_pb,),
-            {},
-        )])
-
-    def _read_row_helper(self, chunks, expected_result):
-        from google.cloud._testing import _Monkey
-        from unit_tests._testing import _FakeStub
-        from google.cloud.bigtable import table as MUT
-
-        client = _Client()
-        instance = _Instance(self.INSTANCE_NAME, client=client)
-        table = self._make_one(self.TABLE_ID, instance)
-
-        # Create request_pb
-        request_pb = object()  # Returned by our mock.
-        mock_created = []
-
-        def mock_create_row_request(table_name, row_key, filter_):
-            mock_created.append((table_name, row_key, filter_))
-            return request_pb
-
-        # Create response_iterator
-        if chunks is None:
-            response_iterator = iter(())  # no responses at all
-        else:
-            response_pb = _ReadRowsResponsePB(chunks=chunks)
-            response_iterator = iter([response_pb])
-
-        # Patch the stub used by the API method.
-        client._data_stub = stub = _FakeStub(response_iterator)
-
-        # Perform the method and check the result.
-        filter_obj = object()
-        with _Monkey(MUT, _create_row_request=mock_create_row_request):
-            result = table.read_row(self.ROW_KEY, filter_=filter_obj)
-
-        self.assertEqual(result, expected_result)
-        self.assertEqual(stub.method_calls, [(
-            'ReadRows',
-            (request_pb,),
-            {},
-        )])
-        self.assertEqual(mock_created,
-                         [(table.name, self.ROW_KEY, filter_obj)])
-
-    def test_read_row_miss_no__responses(self):
-        self._read_row_helper(None, None)
-
-    def test_read_row_miss_no_chunks_in_response(self):
-        chunks = []
-        self._read_row_helper(chunks, None)
-
-    def test_read_row_complete(self):
-        from google.cloud.bigtable.row_data import Cell
-        from google.cloud.bigtable.row_data import PartialRowData
-
-        chunk = _ReadRowsResponseCellChunkPB(
-            row_key=self.ROW_KEY,
-            family_name=self.FAMILY_NAME,
-            qualifier=self.QUALIFIER,
-            timestamp_micros=self.TIMESTAMP_MICROS,
-            value=self.VALUE,
-            commit_row=True,
-        )
-        chunks = [chunk]
-        expected_result = PartialRowData(row_key=self.ROW_KEY)
-        family = expected_result._cells.setdefault(self.FAMILY_NAME, {})
-        column = family.setdefault(self.QUALIFIER, [])
-        column.append(Cell.from_pb(chunk))
-        self._read_row_helper(chunks, expected_result)
-
-    def test_read_row_still_partial(self):
-        chunk = _ReadRowsResponseCellChunkPB(
-            row_key=self.ROW_KEY,
-            family_name=self.FAMILY_NAME,
-            qualifier=self.QUALIFIER,
-            timestamp_micros=self.TIMESTAMP_MICROS,
-            value=self.VALUE,
-        )
-        # No "commit row".
-        chunks = [chunk]
-        with self.assertRaises(ValueError):
-            self._read_row_helper(chunks, None)
-
-    def test_read_rows(self):
-        from google.cloud._testing import _Monkey
-        from unit_tests._testing import _FakeStub
-        from google.cloud.bigtable.row_data import PartialRowsData
-        from google.cloud.bigtable import table as MUT
-
-        client = _Client()
-        instance = _Instance(self.INSTANCE_NAME, client=client)
-        table = self._make_one(self.TABLE_ID, instance)
-
-        # Create request_pb
-        request_pb = object()  # Returned by our mock.
-        mock_created = []
-
-        def mock_create_row_request(table_name, **kwargs):
-            mock_created.append((table_name, kwargs))
-            return request_pb
-
-        # Create response_iterator
-        response_iterator = object()
-
-        # Patch the stub used by the API method.
-        client._data_stub = stub = _FakeStub(response_iterator)
-
-        # Create expected_result.
-        expected_result = PartialRowsData(response_iterator)
-
-        # Perform the method and check the result.
-        start_key = b'start-key'
-        end_key = b'end-key'
-        filter_obj = object()
-        limit = 22
-        with _Monkey(MUT, _create_row_request=mock_create_row_request):
-            result = table.read_rows(
-                start_key=start_key, end_key=end_key, filter_=filter_obj,
-                limit=limit)
-
-        self.assertEqual(result, expected_result)
-        self.assertEqual(stub.method_calls, [(
-            'ReadRows',
-            (request_pb,),
-            {},
-        )])
-        created_kwargs = {
-            'start_key': start_key,
-            'end_key': end_key,
-            'filter_': filter_obj,
-            'limit': limit,
-        }
-        self.assertEqual(mock_created, [(table.name, created_kwargs)])
-
-    def test_sample_row_keys(self):
-        from unit_tests._testing import _FakeStub
-
-        client = _Client()
-        instance = _Instance(self.INSTANCE_NAME, client=client)
-        table = self._make_one(self.TABLE_ID, instance)
-
-        # Create request_pb
-        request_pb = _SampleRowKeysRequestPB(table_name=self.TABLE_NAME)
-
-        # Create response_iterator
-        response_iterator = object()  # Just passed to a mock.
-
-        # Patch the stub used by the API method.
-        client._data_stub = stub = _FakeStub(response_iterator)
-
-        # Create expected_result.
-        expected_result = response_iterator
-
-        # Perform the method and check the result.
-        result = table.sample_row_keys()
-        self.assertEqual(result, expected_result)
-        self.assertEqual(stub.method_calls, [(
-            'SampleRowKeys',
-            (request_pb,),
-            {},
-        )])
-
-
-class Test__create_row_request(unittest.TestCase):
-
-    def _call_fut(self, table_name, row_key=None, start_key=None, end_key=None,
-                  filter_=None, limit=None):
-        from google.cloud.bigtable.table import _create_row_request
-
-        return _create_row_request(
-            table_name, row_key=row_key, start_key=start_key, end_key=end_key,
-            filter_=filter_, limit=limit)
-
-    def test_table_name_only(self):
-        table_name = 'table_name'
-        result = self._call_fut(table_name)
-        expected_result = _ReadRowsRequestPB(
-            table_name=table_name)
-        self.assertEqual(result, expected_result)
-
-    def test_row_key_row_range_conflict(self):
-        with self.assertRaises(ValueError):
-            self._call_fut(None, row_key=object(), end_key=object())
-
-    def test_row_key(self):
-        table_name = 'table_name'
-        row_key = b'row_key'
-        result = self._call_fut(table_name, row_key=row_key)
-        expected_result = _ReadRowsRequestPB(
-            table_name=table_name,
-        )
-        expected_result.rows.row_keys.append(row_key)
-        self.assertEqual(result, expected_result)
-
-    def test_row_range_start_key(self):
-        table_name = 'table_name'
-        start_key = b'start_key'
-        result = self._call_fut(table_name, start_key=start_key)
-        expected_result = _ReadRowsRequestPB(table_name=table_name)
-        expected_result.rows.row_ranges.add(start_key_closed=start_key)
-        self.assertEqual(result, expected_result)
-
-    def test_row_range_end_key(self):
-        table_name = 'table_name'
-        end_key = b'end_key'
-        result = self._call_fut(table_name, end_key=end_key)
-        expected_result = _ReadRowsRequestPB(table_name=table_name)
-        expected_result.rows.row_ranges.add(end_key_open=end_key)
-        self.assertEqual(result, expected_result)
-
-    def test_row_range_both_keys(self):
-        table_name = 'table_name'
-        start_key = b'start_key'
-        end_key = b'end_key'
-        result = self._call_fut(table_name, start_key=start_key,
-                                end_key=end_key)
-        expected_result = _ReadRowsRequestPB(table_name=table_name)
-        expected_result.rows.row_ranges.add(
-            start_key_closed=start_key, end_key_open=end_key)
-        self.assertEqual(result, expected_result)
-
-    def test_with_filter(self):
-        from google.cloud.bigtable.row_filters import RowSampleFilter
-
-        table_name = 'table_name'
-        row_filter = RowSampleFilter(0.33)
-        result = self._call_fut(table_name, filter_=row_filter)
-        expected_result = _ReadRowsRequestPB(
-            table_name=table_name,
-            filter=row_filter.to_pb(),
-        )
-        self.assertEqual(result, expected_result)
-
-    def test_with_limit(self):
-        table_name = 'table_name'
-        limit = 1337
-        result = self._call_fut(table_name, limit=limit)
-        expected_result = _ReadRowsRequestPB(
-            table_name=table_name,
-            rows_limit=limit,
-        )
-        self.assertEqual(result, expected_result)
-
-
-def _CreateTableRequestPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_table_admin_pb2 as table_admin_v2_pb2)
-
-    return table_admin_v2_pb2.CreateTableRequest(*args, **kw)
-
-
-def _CreateTableRequestSplitPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_table_admin_pb2 as table_admin_v2_pb2)
-
-    return table_admin_v2_pb2.CreateTableRequest.Split(*args, **kw)
-
-
-def _DeleteTableRequestPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_table_admin_pb2 as table_admin_v2_pb2)
-
-    return table_admin_v2_pb2.DeleteTableRequest(*args, **kw)
-
-
-def _GetTableRequestPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_table_admin_pb2 as table_admin_v2_pb2)
-
-    return table_admin_v2_pb2.GetTableRequest(*args, **kw)
-
-
-def _ReadRowsRequestPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_pb2 as messages_v2_pb2)
-
-    return messages_v2_pb2.ReadRowsRequest(*args, **kw)
-
-
-def _ReadRowsResponseCellChunkPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_pb2 as messages_v2_pb2)
-
-    family_name = kw.pop('family_name')
-    qualifier = kw.pop('qualifier')
-    message = messages_v2_pb2.ReadRowsResponse.CellChunk(*args, **kw)
-    message.family_name.value = family_name
-    message.qualifier.value = qualifier
-    return message
-
-
-def _ReadRowsResponsePB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_pb2 as messages_v2_pb2)
-
-    return messages_v2_pb2.ReadRowsResponse(*args, **kw)
-
-
-def _SampleRowKeysRequestPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        bigtable_pb2 as messages_v2_pb2)
-
-    return messages_v2_pb2.SampleRowKeysRequest(*args, **kw)
-
-
-def _TablePB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        table_pb2 as table_v2_pb2)
-
-    return table_v2_pb2.Table(*args, **kw)
-
-
-def _ColumnFamilyPB(*args, **kw):
-    from google.cloud.bigtable._generated import (
-        table_pb2 as table_v2_pb2)
-
-    return table_v2_pb2.ColumnFamily(*args, **kw)
-
-
-class _Client(object):
-
-    data_stub = None
-    instance_stub = None
-    operations_stub = None
-    table_stub = None
-
-
-class _Instance(object):
-
-    def __init__(self, name, client=None):
-        self.name = name
-        self._client = client
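
A minimal sketch of the caller-visible behavior this patch provides (illustrative only, not part of the diff; 'my-project', 'my-instance' and the row keys are placeholders, mirroring the pattern used in test_retry.py above):

    from google.cloud.bigtable.client import Client
    from google.cloud.bigtable.instance import Instance

    client = Client(project='my-project', admin=True)
    instance = Instance('my-instance', client)
    table = instance.table('table')

    # read_rows() now returns PartialRowsData wrapped around a
    # ReadRowsIterator, so a DEADLINE_EXCEEDED, ABORTED, INTERNAL or
    # UNAVAILABLE error raised mid-stream is retried with jittered
    # exponential backoff instead of surfacing to the caller. Because
    # consume_next() calls set_start_key() after each chunk, a retried
    # scan resumes from an open interval after the last row key received.
    rows = table.read_rows(b'r1', b'r9')
    rows.consume_all()

Note the retry policy is fixed at module level (RETRY_OPTIONS in table.py); callers cannot yet override the backoff settings per call.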