diff --git a/google/cloud/spanner_dbapi/__init__.py b/google/cloud/spanner_dbapi/__init__.py index 0bb37492db..7695c0058f 100644 --- a/google/cloud/spanner_dbapi/__init__.py +++ b/google/cloud/spanner_dbapi/__init__.py @@ -6,37 +6,39 @@ """Connection-based DB API for Cloud Spanner.""" -from google.cloud import spanner_v1 - -from .connection import Connection -from .exceptions import ( - DatabaseError, - DataError, - Error, - IntegrityError, - InterfaceError, - InternalError, - NotSupportedError, - OperationalError, - ProgrammingError, - Warning, -) -from .parse_utils import get_param_types -from .types import ( - BINARY, - DATETIME, - NUMBER, - ROWID, - STRING, - Binary, - Date, - DateFromTicks, - Time, - TimeFromTicks, - Timestamp, - TimestampFromTicks, -) -from .version import google_client_info +from google.cloud.spanner_dbapi.connection import Connection +from google.cloud.spanner_dbapi.connection import connect + +from google.cloud.spanner_dbapi.cursor import Cursor + +from google.cloud.spanner_dbapi.exceptions import DatabaseError +from google.cloud.spanner_dbapi.exceptions import DataError +from google.cloud.spanner_dbapi.exceptions import Error +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import InternalError +from google.cloud.spanner_dbapi.exceptions import NotSupportedError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError +from google.cloud.spanner_dbapi.exceptions import Warning + +from google.cloud.spanner_dbapi.parse_utils import get_param_types + +from google.cloud.spanner_dbapi.types import BINARY +from google.cloud.spanner_dbapi.types import DATETIME +from google.cloud.spanner_dbapi.types import NUMBER +from google.cloud.spanner_dbapi.types import ROWID +from google.cloud.spanner_dbapi.types import STRING +from google.cloud.spanner_dbapi.types import Binary +from google.cloud.spanner_dbapi.types import Date +from google.cloud.spanner_dbapi.types import DateFromTicks +from google.cloud.spanner_dbapi.types import Time +from google.cloud.spanner_dbapi.types import TimeFromTicks +from google.cloud.spanner_dbapi.types import Timestamp +from google.cloud.spanner_dbapi.types import TimestampStr +from google.cloud.spanner_dbapi.types import TimestampFromTicks + +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT apilevel = "2.0" # supports DP-API 2.0 level. paramstyle = "format" # ANSI C printf format codes, e.g. ...WHERE name=%s. @@ -48,66 +50,10 @@ threadsafety = 1 -def connect( - instance_id, - database_id, - project=None, - credentials=None, - pool=None, - user_agent=None, -): - """ - Create a connection to Cloud Spanner database. - - :type instance_id: :class:`str` - :param instance_id: ID of the instance to connect to. - - :type database_id: :class:`str` - :param database_id: The name of the database to connect to. - - :type project: :class:`str` - :param project: (Optional) The ID of the project which owns the - instances, tables and data. If not provided, will - attempt to determine from the environment. - - :type credentials: :class:`google.auth.credentials.Credentials` - :param credentials: (Optional) The authorization credentials to attach to requests. - These credentials identify this application to the service. - If none are specified, the client will attempt to ascertain - the credentials from the environment. - - :type pool: Concrete subclass of - :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. - :param pool: (Optional). Session pool to be used by database. - - :type user_agent: :class:`str` - :param user_agent: (Optional) User agent to be used with this connection requests. - - :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` - :returns: Connection object associated with the given Cloud Spanner resource. - - :raises: :class:`ValueError` in case of given instance/database - doesn't exist. - """ - client = spanner_v1.Client( - project=project, - credentials=credentials, - client_info=google_client_info(user_agent), - ) - - instance = client.instance(instance_id) - if not instance.exists(): - raise ValueError("instance '%s' does not exist." % instance_id) - - database = instance.database(database_id, pool=pool) - if not database.exists(): - raise ValueError("database '%s' does not exist." % database_id) - - return Connection(instance, database) - - __all__ = [ "Connection", + "connect", + "Cursor", "DatabaseError", "DataError", "Error", @@ -120,7 +66,6 @@ def connect( "Warning", "DEFAULT_USER_AGENT", "apilevel", - "connect", "paramstyle", "threadsafety", "get_param_types", diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py new file mode 100644 index 0000000000..f581fdebbd --- /dev/null +++ b/google/cloud/spanner_dbapi/_helpers.py @@ -0,0 +1,159 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.parse_utils import parse_insert +from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner +from google.cloud.spanner_v1 import param_types + + +SQL_LIST_TABLES = """ + SELECT + t.table_name + FROM + information_schema.tables AS t + WHERE + t.table_catalog = '' and t.table_schema = '' + """ + +SQL_GET_TABLE_COLUMN_SCHEMA = """SELECT + COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE + FROM + INFORMATION_SCHEMA.COLUMNS + WHERE + TABLE_SCHEMA = '' + AND + TABLE_NAME = @table_name + """ + +# This table maps spanner_types to Spanner's data type sizes as per +# https://cloud.google.com/spanner/docs/data-types#allowable-types +# It is used to map `display_size` to a known type for Cursor.description +# after a row fetch. +# Since ResultMetadata +# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata +# does not send back the actual size, we have to lookup the respective size. +# Some fields' sizes are dependent upon the dynamic data hence aren't sent back +# by Cloud Spanner. +code_to_display_size = { + param_types.BOOL.code: 1, + param_types.DATE.code: 4, + param_types.FLOAT64.code: 8, + param_types.INT64.code: 8, + param_types.TIMESTAMP.code: 12, +} + + +def _execute_insert_heterogenous(transaction, sql_params_list): + for sql, params in sql_params_list: + sql, params = sql_pyformat_args_to_spanner(sql, params) + param_types = get_param_types(params) + res = transaction.execute_sql( + sql, params=params, param_types=param_types + ) + # TODO: File a bug with Cloud Spanner and the Python client maintainers + # about a lost commit when res isn't read from. + _ = list(res) + + +def _execute_insert_homogenous(transaction, parts): + # Perform an insert in one shot. + table = parts.get("table") + columns = parts.get("columns") + values = parts.get("values") + return transaction.insert(table, columns, values) + + +def handle_insert(connection, sql, params): + parts = parse_insert(sql, params) + + # The split between the two styles exists because: + # in the common case of multiple values being passed + # with simple pyformat arguments, + # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) + # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] + # we can take advantage of a single RPC with: + # transaction.insert(table, columns, values) + # instead of invoking: + # with transaction: + # for sql, params in sql_params_list: + # transaction.execute_sql(sql, params, param_types) + # which invokes more RPCs and is more costly. + + if parts.get("homogenous"): + # The common case of multiple values being passed in + # non-complex pyformat args and need to be uploaded in one RPC. + return connection.database.run_in_transaction( + _execute_insert_homogenous, parts + ) + else: + # All the other cases that are esoteric and need + # transaction.execute_sql + sql_params_list = parts.get("sql_params_list") + return connection.database.run_in_transaction( + _execute_insert_heterogenous, sql_params_list + ) + + +class ColumnInfo: + """Row column description object.""" + + def __init__( + self, + name, + type_code, + display_size=None, + internal_size=None, + precision=None, + scale=None, + null_ok=False, + ): + self.name = name + self.type_code = type_code + self.display_size = display_size + self.internal_size = internal_size + self.precision = precision + self.scale = scale + self.null_ok = null_ok + + self.fields = ( + self.name, + self.type_code, + self.display_size, + self.internal_size, + self.precision, + self.scale, + self.null_ok, + ) + + def __repr__(self): + return self.__str__() + + def __getitem__(self, index): + return self.fields[index] + + def __str__(self): + str_repr = ", ".join( + filter( + lambda part: part is not None, + [ + "name='%s'" % self.name, + "type_code=%d" % self.type_code, + "display_size=%d" % self.display_size + if self.display_size + else None, + "internal_size=%d" % self.internal_size + if self.internal_size + else None, + "precision='%s'" % self.precision + if self.precision + else None, + "scale='%s'" % self.scale if self.scale else None, + "null_ok='%s'" % self.null_ok if self.null_ok else None, + ], + ) + ) + return "ColumnInfo(%s)" % str_repr diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py index 8907e65c03..b572c8573b 100644 --- a/google/cloud/spanner_dbapi/connection.py +++ b/google/cloud/spanner_dbapi/connection.py @@ -4,23 +4,24 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -"""Cloud Spanner DB connection object.""" +"""DB-API Connection for the Google Cloud Spanner.""" -from collections import namedtuple import warnings -from google.cloud import spanner_v1 +from google.api_core.gapic_v1.client_info import ClientInfo +from google.cloud import spanner_v1 as spanner -from .cursor import Cursor -from .exceptions import InterfaceError +from google.cloud.spanner_dbapi.cursor import Cursor +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.version import DEFAULT_USER_AGENT +from google.cloud.spanner_dbapi.version import PY_VERSION -AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" -ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) +AUTOCOMMIT_MODE_WARNING = "This method is non-operational in autocommit mode" class Connection: - """Representation of a connection to a Cloud Spanner database. + """Representation of a DB-API connection to a Cloud Spanner database. You most likely don't need to instantiate `Connection` objects directly, use the `connect` module function instead. @@ -29,14 +30,14 @@ class Connection: :param instance: Cloud Spanner instance to connect to. :type database: :class:`~google.cloud.spanner_v1.database.Database` - :param database: Cloud Spanner database to connect to. + :param database: The database to which the connection is linked. """ def __init__(self, instance, database): self._instance = instance self._database = database - self._ddl_statements = [] + self._transaction = None self._session = None @@ -54,7 +55,8 @@ def autocommit(self): @autocommit.setter def autocommit(self, value): - """Change this connection autocommit mode. + """Change this connection autocommit mode. Setting this value to True + while a transaction is active will commit the current transaction. :type value: bool :param value: New autocommit mode state. @@ -126,104 +128,17 @@ def transaction_checkout(self): return self._transaction - def cursor(self): - self._raise_if_closed() - - return Cursor(self) - def _raise_if_closed(self): - """Raise an exception if this connection is closed. - - Helper to check the connection state before - running a SQL/DDL/DML query. + """Helper to check the connection state before running a query. + Raises an exception if this connection is closed. - :raises: :class:`InterfaceError` if this connection is closed. + :raises: :class:`InterfaceError`: if this connection is closed. """ if self.is_closed: raise InterfaceError("connection is already closed") - def __handle_update_ddl(self, ddl_statements): - """ - Run the list of Data Definition Language (DDL) statements on the underlying - database. Each DDL statement MUST NOT contain a semicolon. - Args: - ddl_statements: a list of DDL statements, each without a semicolon. - Returns: - google.api_core.operation.Operation.result() - """ - self._raise_if_closed() - # Synchronously wait on the operation's completion. - return self.database.update_ddl(ddl_statements).result() - - def read_snapshot(self): - self._raise_if_closed() - return self.database.snapshot() - - def in_transaction(self, fn, *args, **kwargs): - self._raise_if_closed() - return self.database.run_in_transaction(fn, *args, **kwargs) - - def append_ddl_statement(self, ddl_statement): - self._raise_if_closed() - self._ddl_statements.append(ddl_statement) - - def run_prior_DDL_statements(self): - self._raise_if_closed() - - if not self._ddl_statements: - return - - ddl_statements = self._ddl_statements - self._ddl_statements = [] - - return self.__handle_update_ddl(ddl_statements) - - def list_tables(self): - return self.run_sql_in_snapshot( - """ - SELECT - t.table_name - FROM - information_schema.tables AS t - WHERE - t.table_catalog = '' and t.table_schema = '' - """ - ) - - def run_sql_in_snapshot(self, sql, params=None, param_types=None): - # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions - # hence this method exists to circumvent that limit. - self.run_prior_DDL_statements() - - with self.database.snapshot() as snapshot: - res = snapshot.execute_sql( - sql, params=params, param_types=param_types - ) - return list(res) - - def get_table_column_schema(self, table_name): - rows = self.run_sql_in_snapshot( - """SELECT - COLUMN_NAME, IS_NULLABLE, SPANNER_TYPE - FROM - INFORMATION_SCHEMA.COLUMNS - WHERE - TABLE_SCHEMA = '' - AND - TABLE_NAME = @table_name""", - params={"table_name": table_name}, - param_types={"table_name": spanner_v1.param_types.STRING}, - ) - - column_details = {} - for column_name, is_nullable, spanner_type in rows: - column_details[column_name] = ColumnDetails( - null_ok=is_nullable == "YES", spanner_type=spanner_type - ) - return column_details - def close(self): - """Close this connection. + """Closes this connection. The connection will be unusable from this point forward. If the connection has an active transaction, it will be rolled back. @@ -238,24 +153,109 @@ def close(self): self.is_closed = True def commit(self): - """Commit all the pending transactions.""" - if self.autocommit: + """Commits any pending transaction to the database. + + This method is non-operational in autocommit mode. + """ + if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: self._transaction.commit() self._release_session() def rollback(self): - """Rollback all the pending transactions.""" - if self.autocommit: + """Rolls back any pending transaction. + + This is a no-op if there is no active transaction or if the connection + is in autocommit mode. + """ + if self._autocommit: warnings.warn(AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2) elif self._transaction: self._transaction.rollback() self._release_session() + def cursor(self): + """Factory to create a DB-API Cursor.""" + self._raise_if_closed() + + return Cursor(self) + + def run_prior_DDL_statements(self): + self._raise_if_closed() + + if self._ddl_statements: + ddl_statements = self._ddl_statements + self._ddl_statements = [] + + return self.database.update_ddl(ddl_statements).result() + def __enter__(self): return self def __exit__(self, etype, value, traceback): self.commit() self.close() + + +def connect( + instance_id, + database_id, + project=None, + credentials=None, + pool=None, + user_agent=None, +): + """Creates a connection to a Google Cloud Spanner database. + + :type instance_id: str + :param instance_id: The ID of the instance to connect to. + + :type database_id: str + :param database_id: The ID of the database to connect to. + + :type project: str + :param project: (Optional) The ID of the project which owns the + instances, tables and data. If not provided, will + attempt to determine from the environment. + + :type credentials: :class:`~google.auth.credentials.Credentials` + :param credentials: (Optional) The authorization credentials to attach to + requests. These credentials identify this application + to the service. If none are specified, the client will + attempt to ascertain the credentials from the + environment. + + :type pool: Concrete subclass of + :class:`~google.cloud.spanner_v1.pool.AbstractSessionPool`. + :param pool: (Optional). Session pool to be used by database. + + :type user_agent: str + :param user_agent: (Optional) User agent to be used with this connection's + requests. + + :rtype: :class:`google.cloud.spanner_dbapi.connection.Connection` + :returns: Connection object associated with the given Google Cloud Spanner + resource. + + :raises: :class:`ValueError` in case of given instance/database + doesn't exist. + """ + + client_info = ClientInfo( + user_agent=user_agent or DEFAULT_USER_AGENT, python_version=PY_VERSION, + ) + + client = spanner.Client( + project=project, credentials=credentials, client_info=client_info, + ) + + instance = client.instance(instance_id) + if not instance.exists(): + raise ValueError("instance '%s' does not exist." % instance_id) + + database = instance.database(database_id, pool=pool) + if not database.exists(): + raise ValueError("database '%s' does not exist." % database_id) + + return Connection(instance, database) diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py index 95eae50e1a..6997752a42 100644 --- a/google/cloud/spanner_dbapi/cursor.py +++ b/google/cloud/spanner_dbapi/cursor.py @@ -4,258 +4,98 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -"""Database cursor API.""" - -from google.api_core.exceptions import ( - AlreadyExists, - FailedPrecondition, - InternalServerError, - InvalidArgument, -) -from google.cloud.spanner_v1 import param_types - -from .exceptions import ( - IntegrityError, - InterfaceError, - OperationalError, - ProgrammingError, -) -from .parse_utils import ( - STMT_DDL, - STMT_INSERT, - STMT_NON_UPDATING, - classify_stmt, - ensure_where_clause, - get_param_types, - parse_insert, - sql_pyformat_args_to_spanner, -) -from .utils import PeekIterator +"""Database cursor for Google Cloud Spanner DB-API.""" -_UNSET_COUNT = -1 +from google.api_core.exceptions import AlreadyExists +from google.api_core.exceptions import FailedPrecondition +from google.api_core.exceptions import InternalServerError +from google.api_core.exceptions import InvalidArgument +from collections import namedtuple -# This table maps spanner_types to Spanner's data type sizes as per -# https://cloud.google.com/spanner/docs/data-types#allowable-types -# It is used to map `display_size` to a known type for Cursor.description -# after a row fetch. -# Since ResultMetadata -# https://cloud.google.com/spanner/docs/reference/rest/v1/ResultSetMetadata -# does not send back the actual size, we have to lookup the respective size. -# Some fields' sizes are dependent upon the dynamic data hence aren't sent back -# by Cloud Spanner. -code_to_display_size = { - param_types.BOOL.code: 1, - param_types.DATE.code: 4, - param_types.FLOAT64.code: 8, - param_types.INT64.code: 8, - param_types.TIMESTAMP.code: 12, -} - - -class Cursor: - """ - Database cursor to manage the context of a fetch operation. +from google.cloud import spanner_v1 as spanner - :type connection: :class:`spanner_dbapi.connection.Connection` - :param connection: Parent connection object for this Cursor. - """ +from google.cloud.spanner_dbapi.exceptions import IntegrityError +from google.cloud.spanner_dbapi.exceptions import InterfaceError +from google.cloud.spanner_dbapi.exceptions import OperationalError +from google.cloud.spanner_dbapi.exceptions import ProgrammingError - def __init__(self, connection): - self._itr = None - self._res = None - self._row_count = _UNSET_COUNT - self._connection = connection - self._is_closed = False - - # the number of rows to fetch at a time with fetchmany() - self.arraysize = 1 - - def execute(self, sql, args=None): - """ - Abstracts and implements execute SQL statements on Cloud Spanner. - Args: - sql: A SQL statement - *args: variadic argument list - **kwargs: key worded arguments - Returns: - None - """ - self._raise_if_closed() - - if not self._connection: - raise ProgrammingError("Cursor is not connected to the database") - - self._res = None - - # Classify whether this is a read-only SQL statement. - try: - classification = classify_stmt(sql) - - if classification == STMT_DDL: - self._connection.append_ddl_statement(sql) - return - - # For every other operation, we've got to ensure that - # any prior DDL statements were run. - self._run_prior_DDL_statements() +from google.cloud.spanner_dbapi import _helpers +from google.cloud.spanner_dbapi._helpers import ColumnInfo +from google.cloud.spanner_dbapi._helpers import code_to_display_size - if not self._connection.autocommit: - transaction = self._connection.transaction_checkout() +from google.cloud.spanner_dbapi import parse_utils +from google.cloud.spanner_dbapi.parse_utils import get_param_types +from google.cloud.spanner_dbapi.utils import PeekIterator - sql, params = sql_pyformat_args_to_spanner(sql, args) +_UNSET_COUNT = -1 - self._res = transaction.execute_sql( - sql, params, param_types=get_param_types(params) - ) - self._itr = PeekIterator(self._res) - return +ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"]) - if classification == STMT_NON_UPDATING: - self.__handle_DQL(sql, args or None) - elif classification == STMT_INSERT: - self.__handle_insert(sql, args or None) - else: - self.__handle_update(sql, args or None) - except (AlreadyExists, FailedPrecondition) as e: - raise IntegrityError(e.details if hasattr(e, "details") else e) - except InvalidArgument as e: - raise ProgrammingError(e.details if hasattr(e, "details") else e) - except InternalServerError as e: - raise OperationalError(e.details if hasattr(e, "details") else e) - def __handle_update(self, sql, params): - self._connection.in_transaction(self.__do_execute_update, sql, params) +class Cursor(object): + """Database cursor to manage the context of a fetch operation. - def __do_execute_update(self, transaction, sql, params, param_types=None): - sql = ensure_where_clause(sql) - sql, params = sql_pyformat_args_to_spanner(sql, params) + :type connection: :class:`~google.cloud.spanner_dbapi.connection.Connection` + :param connection: A DB-API connection to Google Cloud Spanner. + """ - res = transaction.execute_update( - sql, params=params, param_types=get_param_types(params) - ) + def __init__(self, connection): self._itr = None - if type(res) == int: - self._row_count = res - - return res - - def __handle_insert(self, sql, params): - parts = parse_insert(sql, params) - - # The split between the two styles exists because: - # in the common case of multiple values being passed - # with simple pyformat arguments, - # SQL: INSERT INTO T (f1, f2) VALUES (%s, %s, %s) - # Params: [(1, 2, 3, 4, 5, 6, 7, 8, 9, 10,)] - # we can take advantage of a single RPC with: - # transaction.insert(table, columns, values) - # instead of invoking: - # with transaction: - # for sql, params in sql_params_list: - # transaction.execute_sql(sql, params, param_types) - # which invokes more RPCs and is more costly. - - if parts.get("homogenous"): - # The common case of multiple values being passed in - # non-complex pyformat args and need to be uploaded in one RPC. - return self._connection.in_transaction( - self.__do_execute_insert_homogenous, parts - ) - else: - # All the other cases that are esoteric and need - # transaction.execute_sql - sql_params_list = parts.get("sql_params_list") - return self._connection.in_transaction( - self.__do_execute_insert_heterogenous, sql_params_list - ) + self._result_set = None + self._row_count = _UNSET_COUNT + self.connection = connection + self._is_closed = False - def __do_execute_insert_heterogenous(self, transaction, sql_params_list): - for sql, params in sql_params_list: - sql, params = sql_pyformat_args_to_spanner(sql, params) - param_types = get_param_types(params) - res = transaction.execute_sql( - sql, params=params, param_types=param_types - ) - # TODO: File a bug with Cloud Spanner and the Python client maintainers - # about a lost commit when res isn't read from. - _ = list(res) - - def __do_execute_insert_homogenous(self, transaction, parts): - # Perform an insert in one shot. - table = parts.get("table") - columns = parts.get("columns") - values = parts.get("values") - return transaction.insert(table, columns, values) - - def __handle_DQL(self, sql, params): - with self._connection.read_snapshot() as snapshot: - # Reference - # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql - sql, params = sql_pyformat_args_to_spanner(sql, params) - res = snapshot.execute_sql( - sql, params=params, param_types=get_param_types(params) - ) - if type(res) == int: - self._row_count = res - self._itr = None - else: - # Immediately using: - # iter(response) - # here, because this Spanner API doesn't provide - # easy mechanisms to detect when only a single item - # is returned or many, yet mixing results that - # are for .fetchone() with those that would result in - # many items returns a RuntimeError if .fetchone() is - # invoked and vice versa. - self._res = res - # Read the first element so that StreamedResult can - # return the metadata after a DQL statement. See issue #155. - self._itr = PeekIterator(self._res) - # Unfortunately, Spanner doesn't seem to send back - # information about the number of rows available. - self._row_count = _UNSET_COUNT + # the number of rows to fetch at a time with fetchmany() + self.arraysize = 1 - def __enter__(self): - return self + @property + def is_closed(self): + """The cursor close indicator. - def __exit__(self, etype, value, traceback): - self.close() + :rtype: bool + :returns: True if the cursor or the parent connection is closed, + otherwise False. + """ + return self._is_closed or self.connection.is_closed @property def description(self): - if not (self._res and self._res.metadata): + """Read-only attribute containing a sequence of the following items: + + - ``name`` + - ``type_code`` + - ``display_size`` + - ``internal_size`` + - ``precision`` + - ``scale`` + - ``null_ok`` + """ + if not (self._result_set and self._result_set.metadata): return None - row_type = self._res.metadata.row_type + row_type = self._result_set.metadata.row_type columns = [] + for field in row_type.fields: - columns.append( - ColumnInfo( - name=field.name, - type_code=field.type.code, - # Size of the SQL type of the column. - display_size=code_to_display_size.get(field.type.code), - # Client perceived size of the column. - internal_size=field.ByteSize(), - ) + column_info = ColumnInfo( + name=field.name, + type_code=field.type.code, + # Size of the SQL type of the column. + display_size=code_to_display_size.get(field.type.code), + # Client perceived size of the column. + internal_size=field.ByteSize(), ) + columns.append(column_info) + return tuple(columns) @property def rowcount(self): + """The number of rows produced by the last `.execute()`.""" return self._row_count - @property - def is_closed(self): - """The cursor close indicator. - - :rtype: :class:`bool` - :returns: True if this cursor or it's parent connection is closed, False - otherwise. - """ - return self._is_closed or self._connection.is_closed - def _raise_if_closed(self): """Raise an exception if this cursor is closed. @@ -266,42 +106,104 @@ def _raise_if_closed(self): :raises: :class:`InterfaceError` if this cursor is closed. """ if self.is_closed: - raise InterfaceError("cursor is already closed") + raise InterfaceError("Cursor and/or connection is already closed.") + + def callproc(self, procname, args=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() def close(self): - """Close this cursor. + """Closes this Cursor, making it unusable from this point forward.""" + self._is_closed = True + + def _do_execute_update(self, transaction, sql, params, param_types=None): + sql = parse_utils.ensure_where_clause(sql) + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) + + result = transaction.execute_update( + sql, params=params, param_types=get_param_types(params) + ) + self._itr = None + if type(result) == int: + self._row_count = result + + return result + + def execute(self, sql, args=None): + """Prepares and executes a Spanner database operation. - The cursor will be unusable from this point forward. + :type sql: str + :param sql: A SQL query statement. + + :type args: list + :param args: Additional parameters to supplement the SQL query. """ - self._is_closed = True + if not self.connection: + raise ProgrammingError("Cursor is not connected to the database") + + self._raise_if_closed() + + self._result_set = None + + # Classify whether this is a read-only SQL statement. + try: + classification = parse_utils.classify_stmt(sql) + if classification == parse_utils.STMT_DDL: + self.connection._ddl_statements.append(sql) + return + + # For every other operation, we've got to ensure that + # any prior DDL statements were run. + # self._run_prior_DDL_statements() + self.connection.run_prior_DDL_statements() + + if not self.connection.autocommit: + transaction = self.connection.transaction_checkout() + + sql, params = parse_utils.sql_pyformat_args_to_spanner( + sql, args + ) + + self._result_set = transaction.execute_sql( + sql, params, param_types=get_param_types(params) + ) + self._itr = PeekIterator(self._result_set) + return + + if classification == parse_utils.STMT_NON_UPDATING: + self._handle_DQL(sql, args or None) + elif classification == parse_utils.STMT_INSERT: + _helpers.handle_insert(self.connection, sql, args or None) + else: + self.connection.database.run_in_transaction( + self._do_execute_update, sql, args or None + ) + except (AlreadyExists, FailedPrecondition) as e: + raise IntegrityError(e.details if hasattr(e, "details") else e) + except InvalidArgument as e: + raise ProgrammingError(e.details if hasattr(e, "details") else e) + except InternalServerError as e: + raise OperationalError(e.details if hasattr(e, "details") else e) def executemany(self, operation, seq_of_params): - """ - Execute the given SQL with every parameters set + """Execute the given SQL with every parameters set from the given sequence of parameters. - :type operation: :class:`str` + :type operation: str :param operation: SQL code to execute. - :type seq_of_params: :class:`list` - :param seq_of_params: Sequence of params to run the query with. + :type seq_of_params: list + :param seq_of_params: Sequence of additional parameters to run + the query with. """ self._raise_if_closed() for params in seq_of_params: self.execute(operation, params) - def __next__(self): - if self._itr is None: - raise ProgrammingError("no results to return") - return next(self._itr) - - def __iter__(self): - if self._itr is None: - raise ProgrammingError("no results to return") - return self._itr - def fetchone(self): + """Fetch the next row of a query result set, returning a single + sequence, or None when no more data is available.""" self._raise_if_closed() try: @@ -309,22 +211,15 @@ def fetchone(self): except StopIteration: return None - def fetchall(self): - self._raise_if_closed() - - return list(self.__iter__()) - def fetchmany(self, size=None): - """ - Fetch the next set of rows of a query result, returning a sequence of sequences. - An empty sequence is returned when no more rows are available. - - Args: - size: optional integer to determine the maximum number of results to fetch. + """Fetch the next set of rows of a query result, returning a sequence + of sequences. An empty sequence is returned when no more rows are available. + :type size: int + :param size: (Optional) The maximum number of results to fetch. - Raises: - Error if the previous call to .execute*() did not produce any result set + :raises InterfaceError: + if the previous call to .execute*() did not produce any result set or if no call was issued yet. """ self._raise_if_closed() @@ -341,85 +236,94 @@ def fetchmany(self, size=None): return items - @property - def lastrowid(self): - return None + def fetchall(self): + """Fetch all (remaining) rows of a query result, returning them as + a sequence of sequences. + """ + self._raise_if_closed() - def setinputsizes(sizes): - raise ProgrammingError("Unimplemented") + return list(self.__iter__()) - def setoutputsize(size, column=None): - raise ProgrammingError("Unimplemented") + def nextset(self): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() - def _run_prior_DDL_statements(self): - return self._connection.run_prior_DDL_statements() + def setinputsizes(self, sizes): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def setoutputsize(self, size, column=None): + """A no-op, raising an error if the cursor or connection is closed.""" + self._raise_if_closed() + + def _handle_DQL(self, sql, params): + with self.connection.database.snapshot() as snapshot: + # Reference + # https://googleapis.dev/python/spanner/latest/session-api.html#google.cloud.spanner_v1.session.Session.execute_sql + sql, params = parse_utils.sql_pyformat_args_to_spanner(sql, params) + res = snapshot.execute_sql( + sql, params=params, param_types=get_param_types(params) + ) + if type(res) == int: + self._row_count = res + self._itr = None + else: + # Immediately using: + # iter(response) + # here, because this Spanner API doesn't provide + # easy mechanisms to detect when only a single item + # is returned or many, yet mixing results that + # are for .fetchone() with those that would result in + # many items returns a RuntimeError if .fetchone() is + # invoked and vice versa. + self._result_set = res + # Read the first element so that the StreamedResultSet can + # return the metadata after a DQL statement. See issue #155. + self._itr = PeekIterator(self._result_set) + # Unfortunately, Spanner doesn't seem to send back + # information about the number of rows available. + self._row_count = _UNSET_COUNT + + def __enter__(self): + return self + + def __exit__(self, etype, value, traceback): + self.close() + + def __next__(self): + if self._itr is None: + raise ProgrammingError("no results to return") + return next(self._itr) + + def __iter__(self): + if self._itr is None: + raise ProgrammingError("no results to return") + return self._itr def list_tables(self): - return self._connection.list_tables() + return self.run_sql_in_snapshot(_helpers.SQL_LIST_TABLES) + + def run_sql_in_snapshot(self, sql, params=None, param_types=None): + # Some SQL e.g. for INFORMATION_SCHEMA cannot be run in read-write transactions + # hence this method exists to circumvent that limit. + self.connection.run_prior_DDL_statements() - def run_sql_in_snapshot(self, sql): - return self._connection.run_sql_in_snapshot(sql) + with self.connection.database.snapshot() as snapshot: + res = snapshot.execute_sql( + sql, params=params, param_types=param_types + ) + return list(res) def get_table_column_schema(self, table_name): - return self._connection.get_table_column_schema(table_name) - - -class ColumnInfo: - """Row column description object.""" - - def __init__( - self, - name, - type_code, - display_size=None, - internal_size=None, - precision=None, - scale=None, - null_ok=False, - ): - self.name = name - self.type_code = type_code - self.display_size = display_size - self.internal_size = internal_size - self.precision = precision - self.scale = scale - self.null_ok = null_ok - - self.fields = ( - self.name, - self.type_code, - self.display_size, - self.internal_size, - self.precision, - self.scale, - self.null_ok, + rows = self.run_sql_in_snapshot( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": spanner.param_types.STRING}, ) - def __repr__(self): - return self.__str__() - - def __getitem__(self, index): - return self.fields[index] - - def __str__(self): - str_repr = ", ".join( - filter( - lambda part: part is not None, - [ - "name='%s'" % self.name, - "type_code=%d" % self.type_code, - "display_size=%d" % self.display_size - if self.display_size - else None, - "internal_size=%d" % self.internal_size - if self.internal_size - else None, - "precision='%s'" % self.precision - if self.precision - else None, - "scale='%s'" % self.scale if self.scale else None, - "null_ok='%s'" % self.null_ok if self.null_ok else None, - ], + column_details = {} + for column_name, is_nullable, spanner_type in rows: + column_details[column_name] = ColumnDetails( + null_ok=is_nullable == "YES", spanner_type=spanner_type ) - ) - return "ColumnInfo(%s)" % str_repr + return column_details diff --git a/google/cloud/spanner_dbapi/parse_utils.py b/google/cloud/spanner_dbapi/parse_utils.py index d0e807435e..084eea315e 100644 --- a/google/cloud/spanner_dbapi/parse_utils.py +++ b/google/cloud/spanner_dbapi/parse_utils.py @@ -262,11 +262,18 @@ def parse_insert(insert_sql, params): if not params: # Case a) perhaps? # Check if any %s exists. - pyformat_str_count = after_values_sql.count("%s") - if pyformat_str_count > 0: - raise ProgrammingError( - 'no params yet there are %d "%s" tokens' % pyformat_str_count - ) + + # pyformat_str_count = after_values_sql.count("%s") + # if pyformat_str_count > 0: + # raise ProgrammingError( + # 'no params yet there are %d "%%s" tokens' % pyformat_str_count + # ) + for item in after_values_sql: + if item.count("%s") > 0: + raise ProgrammingError( + 'no params yet there are %d "%%s" tokens' + % item.count("%s") + ) insert_sql = sanitize_literals_for_upload(insert_sql) # Confirmed case of: diff --git a/google/cloud/spanner_dbapi/parser.py b/google/cloud/spanner_dbapi/parser.py index 755384b2c1..2fc0156b57 100644 --- a/google/cloud/spanner_dbapi/parser.py +++ b/google/cloud/spanner_dbapi/parser.py @@ -33,7 +33,7 @@ VALUES = "VALUES" -class func: +class func(object): def __init__(self, func_name, args): self.name = func_name self.args = args @@ -67,7 +67,7 @@ class terminal(str): pass -class a_args: +class a_args(object): def __init__(self, argv): self.argv = argv @@ -89,13 +89,11 @@ def __eq__(self, other): if type(self) != type(other): return False - s_len, o_len = len(self), len(other) - if s_len != o_len: + if len(self) != len(other): return False - for i, s_item in enumerate(self): - o_item = other[i] - if s_item != o_item: + for i, item in enumerate(self): + if item != other[i]: return False return True @@ -108,7 +106,7 @@ def homogenous(self): Return True if all the arguments are pyformat args and have the same number of arguments. """ - if not self.all_have_same_argc(): + if not self._is_equal_length(): return False for arg in self.argv: @@ -121,7 +119,7 @@ def homogenous(self): return False return True - def all_have_same_argc(self): + def _is_equal_length(self): """ Return False if all the arguments have the same length. """ @@ -200,7 +198,7 @@ def expect(word, token): # (%s, %s...) if not (word and word.startswith("(")): raise ProgrammingError( - "ARGS: supposed to begin with `(` in `%s`" % (word) + "ARGS: supposed to begin with `(` in `%s`" % word ) word = word[1:] @@ -226,7 +224,7 @@ def expect(word, token): if not (word and word.startswith(")")): raise ProgrammingError( - "ARGS: supposed to end with `)` in `%s`" % (word) + "ARGS: supposed to end with `)` in `%s`" % word ) word = word[1:] @@ -235,7 +233,7 @@ def expect(word, token): elif token == EXPR: if word == "%s": # Terminal symbol. - return "", (pyfmt_str) + return "", pyfmt_str # Otherwise we expect a function. return expect(word, FUNC) @@ -244,5 +242,5 @@ def expect(word, token): def as_values(values_stmt): - _, values = parse_values(values_stmt) - return values + _, _values = parse_values(values_stmt) + return _values diff --git a/google/cloud/spanner_dbapi/version.py b/google/cloud/spanner_dbapi/version.py index 563d1b4354..88d8f7cdaf 100644 --- a/google/cloud/spanner_dbapi/version.py +++ b/google/cloud/spanner_dbapi/version.py @@ -4,23 +4,8 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import sys - -from google.api_core.gapic_v1.client_info import ClientInfo +import platform +PY_VERSION = platform.python_version() VERSION = "2.2.0a1" DEFAULT_USER_AGENT = "django_spanner/" + VERSION - -vers = sys.version_info - - -def google_client_info(user_agent=None): - """ - Return a google.api_core.gapic_v1.client_info.ClientInfo - containg the user_agent and python_version for this library - """ - - return ClientInfo( - user_agent=user_agent or DEFAULT_USER_AGENT, - python_version="%d.%d.%d" % (vers.major, vers.minor, vers.micro or 0), - ) diff --git a/noxfile.py b/noxfile.py index a8e2ce58e8..dd7466f6e8 100644 --- a/noxfile.py +++ b/noxfile.py @@ -69,14 +69,14 @@ def default(session): session.run( "py.test", "--quiet", - "--cov=django_spanner", + # "--cov=django_spanner", "--cov=google.cloud", - "--cov=tests.spanner_dbapi", + "--cov=tests.unit", "--cov-append", "--cov-config=.coveragerc", "--cov-report=", - "--cov-fail-under=0", - os.path.join("tests", "spanner_dbapi"), + "--cov-fail-under=90", + os.path.join("tests", "unit"), *session.posargs ) diff --git a/tests/spanner_dbapi/test_cursor.py b/tests/spanner_dbapi/test_cursor.py deleted file mode 100644 index 673a95d3e5..0000000000 --- a/tests/spanner_dbapi/test_cursor.py +++ /dev/null @@ -1,154 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -"""Cursor() class unit tests.""" - -import unittest -from unittest import mock - -from google.cloud.spanner_dbapi import connect, InterfaceError -from google.cloud.spanner_dbapi.cursor import ColumnInfo - - -class TestCursor(unittest.TestCase): - def test_close(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - self.assertFalse(cursor.is_closed) - - cursor.close() - - self.assertTrue(cursor.is_closed) - with self.assertRaises(InterfaceError): - cursor.execute("SELECT * FROM database") - - def test_connection_closed(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - self.assertFalse(cursor.is_closed) - - connection.close() - - self.assertTrue(cursor.is_closed) - with self.assertRaises(InterfaceError): - cursor.execute("SELECT * FROM database") - - def test_executemany_on_closed_cursor(self): - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - cursor.close() - - with self.assertRaises(InterfaceError): - cursor.executemany( - """SELECT * FROM table1 WHERE "col1" = @a1""", () - ) - - def test_executemany(self): - operation = """SELECT * FROM table1 WHERE "col1" = @a1""" - params_seq = ((1,), (2,)) - - with mock.patch( - "google.cloud.spanner_v1.instance.Instance.exists", - return_value=True, - ): - with mock.patch( - "google.cloud.spanner_v1.database.Database.exists", - return_value=True, - ): - connection = connect("test-instance", "test-database") - - cursor = connection.cursor() - with mock.patch( - "google.cloud.spanner_dbapi.cursor.Cursor.execute" - ) as execute_mock: - cursor.executemany(operation, params_seq) - - execute_mock.assert_has_calls( - (mock.call(operation, (1,)), mock.call(operation, (2,))) - ) - - -class TestColumns(unittest.TestCase): - def test_ctor(self): - name = "col-name" - type_code = 8 - display_size = 5 - internal_size = 10 - precision = 3 - scale = None - null_ok = False - - cols = ColumnInfo( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ) - - self.assertEqual(cols.name, name) - self.assertEqual(cols.type_code, type_code) - self.assertEqual(cols.display_size, display_size) - self.assertEqual(cols.internal_size, internal_size) - self.assertEqual(cols.precision, precision) - self.assertEqual(cols.scale, scale) - self.assertEqual(cols.null_ok, null_ok) - self.assertEqual( - cols.fields, - ( - name, - type_code, - display_size, - internal_size, - precision, - scale, - null_ok, - ), - ) - - def test___get_item__(self): - fields = ("col-name", 8, 5, 10, 3, None, False) - cols = ColumnInfo(*fields) - - for i in range(0, 7): - self.assertEqual(cols[i], fields[i]) - - def test___str__(self): - cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) - - self.assertEqual( - str(cols), - "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", - ) diff --git a/tests/spanner_dbapi/test_version.py b/tests/spanner_dbapi/test_version.py deleted file mode 100644 index 9dfed1f55f..0000000000 --- a/tests/spanner_dbapi/test_version.py +++ /dev/null @@ -1,36 +0,0 @@ -# Copyright 2020 Google LLC -# -# Use of this source code is governed by a BSD-style -# license that can be found in the LICENSE file or at -# https://developers.google.com/open-source/licenses/bsd - -import sys -from unittest import TestCase - -from google.api_core.gapic_v1.client_info import ClientInfo -from google.cloud.spanner_dbapi.version import ( - DEFAULT_USER_AGENT, - google_client_info, -) - -vers = sys.version_info - - -class VersionUtils(TestCase): - def test_google_client_info_default_useragent(self): - got = google_client_info().to_grpc_metadata() - want = ClientInfo( - user_agent=DEFAULT_USER_AGENT, - python_version="%d.%d.%d" - % (vers.major, vers.minor, vers.micro or 0), - ).to_grpc_metadata() - self.assertEqual(got, want) - - def test_google_client_info_custom_useragent(self): - got = google_client_info("custom-user-agent").to_grpc_metadata() - want = ClientInfo( - user_agent="custom-user-agent", - python_version="%d.%d.%d" - % (vers.major, vers.minor, vers.micro or 0), - ).to_grpc_metadata() - self.assertEqual(got, want) diff --git a/tox.ini b/tests/unit/__init__.py similarity index 100% rename from tox.ini rename to tests/unit/__init__.py diff --git a/tests/spanner_dbapi/__init__.py b/tests/unit/spanner_dbapi/__init__.py similarity index 100% rename from tests/spanner_dbapi/__init__.py rename to tests/unit/spanner_dbapi/__init__.py diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py new file mode 100644 index 0000000000..e5316d254e --- /dev/null +++ b/tests/unit/spanner_dbapi/test__helpers.py @@ -0,0 +1,130 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import unittest + +from unittest import mock + + +class TestHelpers(unittest.TestCase): + def test__execute_insert_heterogenous(self): + from google.cloud.spanner_dbapi import _helpers + + sql = "sql" + params = (sql, None) + with mock.patch( + "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner", + return_value=params, + ) as mock_pyformat: + with mock.patch( + "google.cloud.spanner_dbapi._helpers.get_param_types", + return_value=None, + ) as mock_param_types: + transaction = mock.MagicMock() + transaction.execute_sql = mock_execute = mock.MagicMock() + _helpers._execute_insert_heterogenous(transaction, [params]) + + mock_pyformat.assert_called_once_with(params[0], params[1]) + mock_param_types.assert_called_once_with(None) + mock_execute.assert_called_once_with( + sql, params=None, param_types=None + ) + + def test__execute_insert_homogenous(self): + from google.cloud.spanner_dbapi import _helpers + + transaction = mock.MagicMock() + transaction.insert = mock.MagicMock() + parts = mock.MagicMock() + parts.get = mock.MagicMock(return_value=0) + + _helpers._execute_insert_homogenous(transaction, parts) + transaction.insert.assert_called_once_with(0, 0, 0) + + def test_handle_insert(self): + from google.cloud.spanner_dbapi import _helpers + + connection = mock.MagicMock() + connection.database.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + parts = mock.MagicMock() + with mock.patch( + "google.cloud.spanner_dbapi._helpers.parse_insert", + return_value=parts, + ): + parts.get = mock.MagicMock(return_value=True) + mock_run_in.return_value = 0 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 0) + + parts.get = mock.MagicMock(return_value=False) + mock_run_in.return_value = 1 + result = _helpers.handle_insert(connection, sql, None) + self.assertEqual(result, 1) + + +class TestColumnInfo(unittest.TestCase): + def test_ctor(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + name = "col-name" + type_code = 8 + display_size = 5 + internal_size = 10 + precision = 3 + scale = None + null_ok = False + + cols = ColumnInfo( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ) + + self.assertEqual(cols.name, name) + self.assertEqual(cols.type_code, type_code) + self.assertEqual(cols.display_size, display_size) + self.assertEqual(cols.internal_size, internal_size) + self.assertEqual(cols.precision, precision) + self.assertEqual(cols.scale, scale) + self.assertEqual(cols.null_ok, null_ok) + self.assertEqual( + cols.fields, + ( + name, + type_code, + display_size, + internal_size, + precision, + scale, + null_ok, + ), + ) + + def test___get_item__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + fields = ("col-name", 8, 5, 10, 3, None, False) + cols = ColumnInfo(*fields) + + for i in range(0, 7): + self.assertEqual(cols[i], fields[i]) + + def test___str__(self): + from google.cloud.spanner_dbapi.cursor import ColumnInfo + + cols = ColumnInfo("col-name", 8, None, 10, 3, None, False) + + self.assertEqual( + str(cols), + "ColumnInfo(name='col-name', type_code=8, internal_size=10, precision='3')", + ) diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py new file mode 100644 index 0000000000..d545472c57 --- /dev/null +++ b/tests/unit/spanner_dbapi/test_connection.py @@ -0,0 +1,318 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cloud Spanner DB-API Connection class unit tests.""" + +import unittest +import warnings + +from unittest import mock + + +def _make_credentials(): + from google.auth import credentials + + class _CredentialsWithScopes(credentials.Credentials, credentials.Scoped): + pass + + return mock.Mock(spec=_CredentialsWithScopes) + + +class TestConnection(unittest.TestCase): + + PROJECT = "test-project" + INSTANCE = "test-instance" + DATABASE = "test-database" + USER_AGENT = "user-agent" + CREDENTIALS = _make_credentials() + + def _get_client_info(self): + from google.api_core.gapic_v1.client_info import ClientInfo + + return ClientInfo(user_agent=self.USER_AGENT) + + def _make_connection(self): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_v1.instance import Instance + + # We don't need a real Client object to test the constructor + instance = Instance(self.INSTANCE, client=None) + database = instance.database(self.DATABASE) + return Connection(instance, database) + + def test_property_autocommit_setter(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = True + mock_commit.assert_called_once_with() + self.assertEqual(connection._autocommit, True) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection.commit" + ) as mock_commit: + connection.autocommit = False + mock_commit.assert_not_called() + self.assertEqual(connection._autocommit, False) + + def test_property_database(self): + from google.cloud.spanner_v1.database import Database + + connection = self._make_connection() + self.assertIsInstance(connection.database, Database) + self.assertEqual(connection.database, connection._database) + + def test_property_instance(self): + from google.cloud.spanner_v1.instance import Instance + + connection = self._make_connection() + self.assertIsInstance(connection.instance, Instance) + self.assertEqual(connection.instance, connection._instance) + + def test__session_checkout(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.get = mock.MagicMock( + return_value="db_session_pool" + ) + connection = Connection(self.INSTANCE, mock_database) + + connection._session_checkout() + mock_database._pool.get.assert_called_once_with() + self.assertEqual(connection._session, "db_session_pool") + + connection._session = "db_session" + connection._session_checkout() + self.assertEqual(connection._session, "db_session") + + def test__release_session(self): + from google.cloud.spanner_dbapi import Connection + + with mock.patch( + "google.cloud.spanner_v1.database.Database", + ) as mock_database: + mock_database._pool = mock.MagicMock() + mock_database._pool.put = mock.MagicMock() + connection = Connection(self.INSTANCE, mock_database) + connection._session = "session" + + connection._release_session() + mock_database._pool.put.assert_called_once_with("session") + self.assertIsNone(connection._session) + + def test_transaction_checkout(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + connection._session_checkout = mock_checkout = mock.MagicMock( + autospec=True + ) + connection.transaction_checkout() + mock_checkout.assert_called_once_with() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + self.assertEqual(connection.transaction_checkout(), mock_transaction) + + connection._autocommit = True + self.assertIsNone(connection.transaction_checkout()) + + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertFalse(connection.is_closed) + connection.close() + self.assertTrue(connection.is_closed) + + with self.assertRaises(InterfaceError): + connection.cursor() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.committed = mock_transaction.rolled_back = False + mock_transaction.rollback = mock_rollback = mock.MagicMock() + connection.close() + mock_rollback.assert_called_once_with() + + @mock.patch.object(warnings, "warn") + def test_commit(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.commit = mock_commit = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.commit() + mock_commit.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.commit() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + @mock.patch.object(warnings, "warn") + def test_rollback(self, mock_warn): + from google.cloud.spanner_dbapi import Connection + from google.cloud.spanner_dbapi.connection import ( + AUTOCOMMIT_MODE_WARNING, + ) + + connection = Connection(self.INSTANCE, self.DATABASE) + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_release.assert_not_called() + + connection._transaction = mock_transaction = mock.MagicMock() + mock_transaction.rollback = mock_rollback = mock.MagicMock() + + with mock.patch( + "google.cloud.spanner_dbapi.connection.Connection._release_session" + ) as mock_release: + connection.rollback() + mock_rollback.assert_called_once_with() + mock_release.assert_called_once_with() + + connection._autocommit = True + connection.rollback() + mock_warn.assert_called_once_with( + AUTOCOMMIT_MODE_WARNING, UserWarning, stacklevel=2 + ) + + def test_run_prior_DDL_statements(self): + from google.cloud.spanner_dbapi import Connection, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.database.Database", autospec=True, + ) as mock_database: + connection = Connection(self.INSTANCE, mock_database) + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_not_called() + + ddl = ["ddl"] + connection._ddl_statements = ddl + + connection.run_prior_DDL_statements() + mock_database.update_ddl.assert_called_once_with(ddl) + + connection.is_closed = True + + with self.assertRaises(InterfaceError): + connection.run_prior_DDL_statements() + + def test_context(self): + from google.cloud.spanner_dbapi import Connection + + connection = Connection(self.INSTANCE, self.DATABASE) + with connection as conn: + self.assertEqual(conn, connection) + + self.assertTrue(connection.is_closed) + + def test_connect(self): + from google.cloud.spanner_dbapi import Connection, connect + + with mock.patch("google.cloud.spanner_v1.Client"): + with mock.patch( + "google.api_core.gapic_v1.client_info.ClientInfo", + return_value=self._get_client_info(), + ): + connection = connect( + self.INSTANCE, + self.DATABASE, + self.PROJECT, + self.CREDENTIALS, + self.USER_AGENT, + ) + self.assertIsInstance(connection, Connection) + + def test_connect_instance_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=False, + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_connect_database_not_found(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=False, + ): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with self.assertRaises(ValueError): + connect("test-instance", "test-database") + + def test_default_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + + with mock.patch("google.cloud.spanner_v1.instance.Instance.database"): + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + self.assertIsNotNone(connection.database._pool) + + def test_sessions_pool(self): + from google.cloud.spanner_dbapi import connect + from google.cloud.spanner_v1.pool import FixedSizePool + + database_id = "test-database" + pool = FixedSizePool() + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.database" + ) as database_mock: + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + connect("test-instance", database_id, pool=pool) + database_mock.assert_called_once_with(database_id, pool=pool) diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py new file mode 100644 index 0000000000..09288df94e --- /dev/null +++ b/tests/unit/spanner_dbapi/test_cursor.py @@ -0,0 +1,460 @@ +# Copyright 2020 Google LLC +# +# Use of this source code is governed by a BSD-style +# license that can be found in the LICENSE file or at +# https://developers.google.com/open-source/licenses/bsd + +"""Cursor() class unit tests.""" + +import unittest + +from unittest import mock + + +class TestCursor(unittest.TestCase): + + INSTANCE = "test-instance" + DATABASE = "test-database" + + def _get_target_class(self): + from google.cloud.spanner_dbapi import Cursor + + return Cursor + + def _make_one(self, *args, **kwargs): + return self._get_target_class()(*args, **kwargs) + + def _make_connection(self, *args, **kwargs): + from google.cloud.spanner_dbapi import Connection + + return Connection(*args, **kwargs) + + def test_property_connection(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertEqual(cursor.connection, connection) + + def test_property_description(self): + from google.cloud.spanner_dbapi._helpers import ColumnInfo + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + self.assertIsNone(cursor.description) + cursor._result_set = res_set = mock.MagicMock() + res_set.metadata.row_type.fields = [mock.MagicMock()] + self.assertIsNotNone(cursor.description) + self.assertIsInstance(cursor.description[0], ColumnInfo) + + def test_property_rowcount(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + self.assertEqual(cursor.rowcount, _UNSET_COUNT) + + def test_callproc(self): + from google.cloud.spanner_dbapi.exceptions import InterfaceError + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + cursor._is_closed = True + with self.assertRaises(InterfaceError): + cursor.callproc(procname=None) + + def test_close(self): + from google.cloud.spanner_dbapi import connect, InterfaceError + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect(self.INSTANCE, self.DATABASE) + + cursor = connection.cursor() + self.assertFalse(cursor.is_closed) + + cursor.close() + + self.assertTrue(cursor.is_closed) + with self.assertRaises(InterfaceError): + cursor.execute("SELECT * FROM database") + + def test_do_execute_update(self): + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + transaction = mock.MagicMock() + + def run_helper(ret_value): + transaction.execute_update.return_value = ret_value + res = cursor._do_execute_update( + transaction=transaction, sql="sql", params=None, + ) + return res + + expected = "good" + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + expected = 1234 + self.assertEqual(run_helper(expected), expected) + self.assertEqual(cursor._row_count, expected) + + def test_execute_programming_error(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + cursor.connection = None + with self.assertRaises(ProgrammingError): + cursor.execute(sql="") + + def test_execute_attribute_error(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + with self.assertRaises(AttributeError): + cursor.execute(sql="") + + def test_execute_autocommit_off(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.connection._autocommit = False + cursor.connection.transaction_checkout = mock.MagicMock(autospec=True) + + cursor.execute("sql") + self.assertIsInstance(cursor._result_set, mock.MagicMock) + self.assertIsInstance(cursor._itr, PeekIterator) + + def test_execute_statement(self): + from google.cloud.spanner_dbapi import parse_utils + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_DDL, + ) as mock_classify_stmt: + sql = "sql" + cursor.execute(sql=sql) + mock_classify_stmt.assert_called_once_with(sql) + self.assertEqual(cursor.connection._ddl_statements, [sql]) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_NON_UPDATING, + ): + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor._handle_DQL", + return_value=parse_utils.STMT_NON_UPDATING, + ) as mock_handle_ddl: + connection.autocommit = True + sql = "sql" + cursor.execute(sql=sql) + mock_handle_ddl.assert_called_once_with(sql, None) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value=parse_utils.STMT_INSERT, + ): + with mock.patch( + "google.cloud.spanner_dbapi._helpers.handle_insert", + return_value=parse_utils.STMT_INSERT, + ) as mock_handle_insert: + sql = "sql" + cursor.execute(sql=sql) + mock_handle_insert.assert_called_once_with( + connection, sql, None + ) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + return_value="other_statement", + ): + cursor.connection._database = mock_db = mock.MagicMock() + mock_db.run_in_transaction = mock_run_in = mock.MagicMock() + sql = "sql" + cursor.execute(sql=sql) + mock_run_in.assert_called_once_with( + cursor._do_execute_update, sql, None + ) + + def test_execute_integrity_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import IntegrityError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.AlreadyExists("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.FailedPrecondition("message"), + ): + with self.assertRaises(IntegrityError): + cursor.execute(sql="sql") + + def test_execute_invalid_argument(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InvalidArgument("message"), + ): + with self.assertRaises(ProgrammingError): + cursor.execute(sql="sql") + + def test_execute_internal_server_error(self): + from google.api_core import exceptions + from google.cloud.spanner_dbapi.exceptions import OperationalError + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + + with mock.patch( + "google.cloud.spanner_dbapi.parse_utils.classify_stmt", + side_effect=exceptions.InternalServerError("message"), + ): + with self.assertRaises(OperationalError): + cursor.execute(sql="sql") + + def test_executemany_on_closed_cursor(self): + from google.cloud.spanner_dbapi import InterfaceError + from google.cloud.spanner_dbapi import connect + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + cursor.close() + + with self.assertRaises(InterfaceError): + cursor.executemany( + """SELECT * FROM table1 WHERE "col1" = @a1""", () + ) + + def test_executemany(self): + from google.cloud.spanner_dbapi import connect + + operation = """SELECT * FROM table1 WHERE "col1" = @a1""" + params_seq = ((1,), (2,)) + + with mock.patch( + "google.cloud.spanner_v1.instance.Instance.exists", + return_value=True, + ): + with mock.patch( + "google.cloud.spanner_v1.database.Database.exists", + return_value=True, + ): + connection = connect("test-instance", "test-database") + + cursor = connection.cursor() + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.execute" + ) as execute_mock: + cursor.executemany(operation, params_seq) + + execute_mock.assert_has_calls( + (mock.call(operation, (1,)), mock.call(operation, (2,))) + ) + + def test_fetchone(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [1, 2, 3] + cursor._itr = iter(lst) + for i in range(len(lst)): + self.assertEqual(cursor.fetchone(), lst[i]) + self.assertIsNone(cursor.fetchone()) + + def test_fetchmany(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + + self.assertEqual(cursor.fetchmany(), [lst[0]]) + + result = cursor.fetchmany(len(lst)) + self.assertEqual(result, lst[1:]) + + def test_fetchall(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + self.assertEqual(cursor.fetchall(), lst) + + def test_nextset(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.nextset() + + def test_setinputsizes(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setinputsizes(sizes=None) + + def test_setoutputsize(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + cursor = self._make_one(connection) + cursor.close() + with self.assertRaises(exceptions.InterfaceError): + cursor.setoutputsize(size=None) + + # def test_handle_insert(self): + # pass + # + # def test_do_execute_insert_heterogenous(self): + # pass + # + # def test_do_execute_insert_homogenous(self): + # pass + + def test_handle_dql(self): + from google.cloud.spanner_dbapi import utils + from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT + + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + mock_snapshot.execute_sql.return_value = int(0) + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._row_count, 0) + self.assertIsNone(cursor._itr) + + mock_snapshot.execute_sql.return_value = "0" + cursor._handle_DQL("sql", params=None) + self.assertEqual(cursor._result_set, "0") + self.assertIsInstance(cursor._itr, utils.PeekIterator) + self.assertEqual(cursor._row_count, _UNSET_COUNT) + + def test_context(self): + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with cursor as c: + self.assertEqual(c, cursor) + + self.assertTrue(c.is_closed) + + def test_next(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + cursor.__next__() + + lst = [(1,), (2,), (3,)] + cursor._itr = iter(lst) + i = 0 + for c in cursor._itr: + self.assertEqual(c, lst[i]) + i += 1 + + def test_iter(self): + from google.cloud.spanner_dbapi import exceptions + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + with self.assertRaises(exceptions.ProgrammingError): + _ = iter(cursor) + + iterator = iter([(1,), (2,), (3,)]) + cursor._itr = iterator + self.assertEqual(iter(cursor), iterator) + + def test_list_tables(self): + from google.cloud.spanner_dbapi import _helpers + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + table_list = ["table1", "table2", "table3"] + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=table_list, + ) as mock_run_sql: + cursor.list_tables() + mock_run_sql.assert_called_once_with(_helpers.SQL_LIST_TABLES) + + def test_run_sql_in_snapshot(self): + connection = self._make_connection(self.INSTANCE, mock.MagicMock()) + connection.database.snapshot.return_value.__enter__.return_value = ( + mock_snapshot + ) = mock.MagicMock() + cursor = self._make_one(connection) + + results = 1, 2, 3 + mock_snapshot.execute_sql.return_value = results + self.assertEqual(cursor.run_sql_in_snapshot("sql"), list(results)) + + def test_get_table_column_schema(self): + from google.cloud.spanner_dbapi.cursor import ColumnDetails + from google.cloud.spanner_dbapi import _helpers + from google.cloud.spanner_v1 import param_types + + connection = self._make_connection(self.INSTANCE, self.DATABASE) + cursor = self._make_one(connection) + + column_name = "column1" + is_nullable = "YES" + spanner_type = "spanner_type" + rows = [(column_name, is_nullable, spanner_type)] + expected = { + column_name: ColumnDetails( + null_ok=True, spanner_type=spanner_type, + ) + } + with mock.patch( + "google.cloud.spanner_dbapi.cursor.Cursor.run_sql_in_snapshot", + return_value=rows, + ) as mock_run_sql: + table_name = "table1" + result = cursor.get_table_column_schema(table_name=table_name) + mock_run_sql.assert_called_once_with( + sql=_helpers.SQL_GET_TABLE_COLUMN_SCHEMA, + params={"table_name": table_name}, + param_types={"table_name": param_types.STRING}, + ) + self.assertEqual(result, expected) diff --git a/tests/spanner_dbapi/test_globals.py b/tests/unit/spanner_dbapi/test_globals.py similarity index 67% rename from tests/spanner_dbapi/test_globals.py rename to tests/unit/spanner_dbapi/test_globals.py index 7c3e0396a9..3f8360e2ea 100644 --- a/tests/spanner_dbapi/test_globals.py +++ b/tests/unit/spanner_dbapi/test_globals.py @@ -4,13 +4,15 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from unittest import TestCase +import unittest -from google.cloud.spanner_dbapi import apilevel, paramstyle, threadsafety - -class DBAPIGlobalsTests(TestCase): +class TestDBAPIGlobals(unittest.TestCase): def test_apilevel(self): + from google.cloud.spanner_dbapi import apilevel + from google.cloud.spanner_dbapi import paramstyle + from google.cloud.spanner_dbapi import threadsafety + self.assertEqual(apilevel, "2.0", "We implement PEP-0249 version 2.0") self.assertEqual(paramstyle, "format", "Cloud Spanner uses @param") self.assertEqual( diff --git a/tests/spanner_dbapi/test_parse_utils.py b/tests/unit/spanner_dbapi/test_parse_utils.py similarity index 86% rename from tests/spanner_dbapi/test_parse_utils.py rename to tests/unit/spanner_dbapi/test_parse_utils.py index 815811ce21..1bd38c85eb 100644 --- a/tests/spanner_dbapi/test_parse_utils.py +++ b/tests/unit/spanner_dbapi/test_parse_utils.py @@ -4,33 +4,19 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import datetime -import decimal -from unittest import TestCase +import unittest from google.cloud.spanner_v1 import param_types -from google.cloud.spanner_dbapi.exceptions import Error, ProgrammingError -from google.cloud.spanner_dbapi.parse_utils import ( - STMT_DDL, - STMT_INSERT, - STMT_NON_UPDATING, - STMT_UPDATING, - DateStr, - TimestampStr, - cast_for_spanner, - classify_stmt, - ensure_where_clause, - escape_name, - get_param_types, - parse_insert, - rows_for_insert_or_update, - sql_pyformat_args_to_spanner, -) -from google.cloud.spanner_dbapi.utils import backtick_unicode - - -class ParseUtilsTests(TestCase): + + +class TestParseUtils(unittest.TestCase): def test_classify_stmt(self): + from google.cloud.spanner_dbapi.parse_utils import STMT_DDL + from google.cloud.spanner_dbapi.parse_utils import STMT_INSERT + from google.cloud.spanner_dbapi.parse_utils import STMT_NON_UPDATING + from google.cloud.spanner_dbapi.parse_utils import STMT_UPDATING + from google.cloud.spanner_dbapi.parse_utils import classify_stmt + cases = ( ("SELECT 1", STMT_NON_UPDATING), ("SELECT s.SongName FROM Songs AS s", STMT_NON_UPDATING), @@ -61,6 +47,12 @@ def test_classify_stmt(self): self.assertEqual(classify_stmt(query), want_class) def test_parse_insert(self): + from google.cloud.spanner_dbapi.parse_utils import parse_insert + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + + with self.assertRaises(ProgrammingError): + parse_insert("bad-sql", None) + cases = [ ( "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)", @@ -173,6 +165,10 @@ def test_parse_insert(self): ), ] + sql = "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s)" + with self.assertRaises(ProgrammingError): + parse_insert(sql, None) + for sql, params, want in cases: with self.subTest(sql=sql): got = parse_insert(sql, params) @@ -181,6 +177,9 @@ def test_parse_insert(self): ) def test_parse_insert_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import parse_insert + cases = [ ( "INSERT INTO django_migrations (app, name, applied) VALUES (%s, %s, %s), (%s, %s, %s)", @@ -202,12 +201,23 @@ def test_parse_insert_invalid(self): for sql, params, wantException in cases: with self.subTest(sql=sql): self.assertRaisesRegex( - ProgrammingError, + exceptions.ProgrammingError, wantException, lambda: parse_insert(sql, params), ) def test_rows_for_insert_or_update(self): + from google.cloud.spanner_dbapi.parse_utils import ( + rows_for_insert_or_update, + ) + from google.cloud.spanner_dbapi.exceptions import Error + + with self.assertRaises(Error): + rows_for_insert_or_update([0], [[]]) + + with self.assertRaises(Error): + rows_for_insert_or_update([0], None, ["0", "%s"]) + cases = [ ( ["id", "app", "name"], @@ -255,6 +265,12 @@ def test_rows_for_insert_or_update(self): self.assertEqual(got, want) def test_sql_pyformat_args_to_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import ( + sql_pyformat_args_to_spanner, + ) + cases = [ ( ( @@ -323,6 +339,11 @@ def test_sql_pyformat_args_to_spanner(self): ) def test_sql_pyformat_args_to_spanner_invalid(self): + from google.cloud.spanner_dbapi import exceptions + from google.cloud.spanner_dbapi.parse_utils import ( + sql_pyformat_args_to_spanner, + ) + cases = [ ( "SELECT * from t WHERE f1=%s, f2 = %s, f3=%s, extra=%s", @@ -332,12 +353,28 @@ def test_sql_pyformat_args_to_spanner_invalid(self): for sql, params in cases: with self.subTest(sql=sql): self.assertRaisesRegex( - Error, + exceptions.Error, "pyformat_args mismatch", lambda: sql_pyformat_args_to_spanner(sql, params), ) + def test_cast_for_spanner(self): + import decimal + + from google.cloud.spanner_dbapi.parse_utils import cast_for_spanner + + value = decimal.Decimal(3) + self.assertEqual(cast_for_spanner(value), float(3.0)) + self.assertEqual(cast_for_spanner(5), 5) + self.assertEqual(cast_for_spanner("string"), "string") + def test_get_param_types(self): + import datetime + + from google.cloud.spanner_dbapi.parse_utils import DateStr + from google.cloud.spanner_dbapi.parse_utils import TimestampStr + from google.cloud.spanner_dbapi.parse_utils import get_param_types + params = { "a1": 10, "b1": "string", @@ -365,15 +402,13 @@ def test_get_param_types(self): self.assertEqual(got_types, want_types) def test_get_param_types_none(self): - self.assertEqual(get_param_types(None), None) + from google.cloud.spanner_dbapi.parse_utils import get_param_types - def test_cast_for_spanner(self): - value = decimal.Decimal(3) - self.assertEqual(cast_for_spanner(value), float(3.0)) - self.assertEqual(cast_for_spanner(5), 5) - self.assertEqual(cast_for_spanner("string"), "string") + self.assertEqual(get_param_types(None), None) def test_ensure_where_clause(self): + from google.cloud.spanner_dbapi.parse_utils import ensure_where_clause + cases = [ ( "UPDATE a SET a.b=10 FROM articles a JOIN d c ON a.ai = c.ai WHERE c.ci = 1", @@ -404,6 +439,8 @@ def test_ensure_where_clause(self): self.assertEqual(got, want) def test_escape_name(self): + from google.cloud.spanner_dbapi.parse_utils import escape_name + cases = ( ("SELECT", "`SELECT`"), ("dashed-value", "`dashed-value`"), @@ -415,16 +452,3 @@ def test_escape_name(self): with self.subTest(name=name): got = escape_name(name) self.assertEqual(got, want) - - def test_backtick_unicode(self): - cases = [ - ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), - ("SELECT (1) as föö", "SELECT (1) as `föö`"), - ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), - ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), - ("SELECT (1) as `föö", "SELECT (1) as `föö"), - ] - for sql, want in cases: - with self.subTest(sql=sql): - got = backtick_unicode(sql) - self.assertEqual(got, want) diff --git a/tests/spanner_dbapi/test_parser.py b/tests/unit/spanner_dbapi/test_parser.py similarity index 53% rename from tests/spanner_dbapi/test_parser.py rename to tests/unit/spanner_dbapi/test_parser.py index 9aecf38e42..d5baf9d824 100644 --- a/tests/spanner_dbapi/test_parser.py +++ b/tests/unit/spanner_dbapi/test_parser.py @@ -4,23 +4,19 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from unittest import TestCase - -from google.cloud.spanner_dbapi.exceptions import ProgrammingError -from google.cloud.spanner_dbapi.parser import ( - ARGS, - FUNC, - VALUES, - a_args, - expect, - func, - pyfmt_str, - values, -) - - -class ParserTests(TestCase): +import unittest + +from unittest import mock + + +class TestParser(unittest.TestCase): def test_func(self): + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + cases = [ ("_91())", ")", func("_91", a_args([]))), ("_a()", "", func("_a", a_args([]))), @@ -61,6 +57,10 @@ def test_func(self): self.assertEqual(got_unconsumed, want_unconsumed) def test_func_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + cases = [ ("", "FUNC: `` does not begin with `a-zA-z` nor a `_`"), ("91", "FUNC: `91` does not begin with `a-zA-z` nor a `_`"), @@ -76,7 +76,30 @@ def test_func_fail(self): ProgrammingError, wantException, lambda: expect(text, FUNC) ) + def test_func_eq(self): + from google.cloud.spanner_dbapi.parser import func + + func1 = func("func1", None) + func2 = func("func2", None) + self.assertFalse(func1 == object) + self.assertFalse(func1 == func2) + func2.name = func1.name + func1.args = 0 + func2.args = "0" + self.assertFalse(func1 == func2) + func1.args = [0] + func2.args = [0, 0] + self.assertFalse(func1 == func2) + func2.args = func1.args + self.assertTrue(func1 == func2) + def test_a_args(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + cases = [ ("()", "", a_args([])), ("(%s)", "", a_args([pyfmt_str])), @@ -102,6 +125,10 @@ def test_a_args(self): self.assertEqual(got_unconsumed, want_unconsumed) def test_a_args_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import expect + cases = [ ("", "ARGS: supposed to begin with `\\(`"), ("(", "ARGS: supposed to end with `\\)`"), @@ -115,7 +142,80 @@ def test_a_args_fail(self): ProgrammingError, wantException, lambda: expect(text, ARGS) ) + def test_a_args_has_expr(self): + from google.cloud.spanner_dbapi.parser import a_args + + self.assertFalse(a_args([]).has_expr()) + self.assertTrue(a_args([[0]]).has_expr()) + + def test_a_args_eq(self): + from google.cloud.spanner_dbapi.parser import a_args + + a1 = a_args([0]) + self.assertFalse(a1 == object()) + a2 = a_args([0, 0]) + self.assertFalse(a1 == a2) + a1.argv = [0, 1] + self.assertFalse(a1 == a2) + a2.argv = [0, 1] + self.assertTrue(a1 == a2) + + def test_a_args_homogeneous(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) + self.assertTrue(a_obj.homogenous()) + + a_obj = a_args([a_args([[object()]]) for _ in range(10)]) + self.assertFalse(a_obj.homogenous()) + + def test_a_args__is_equal_length(self): + from google.cloud.spanner_dbapi.parser import a_args + + a_obj = a_args([]) + self.assertTrue(a_obj._is_equal_length()) + + def test_values(self): + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import terminal + from google.cloud.spanner_dbapi.parser import values + + a_obj = a_args([a_args([terminal(10 ** i)]) for i in range(10)]) + self.assertEqual(str(values(a_obj)), "VALUES%s" % str(a_obj)) + + def test_expect(self): + from google.cloud.spanner_dbapi.parser import ARGS + from google.cloud.spanner_dbapi.parser import EXPR + from google.cloud.spanner_dbapi.parser import FUNC + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi import exceptions + + with self.assertRaises(exceptions.ProgrammingError): + expect(word="", token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word="ABC", token=ARGS) + with self.assertRaises(exceptions.ProgrammingError): + expect(word="(", token=ARGS) + + expected = "", pyfmt_str + self.assertEqual(expect("%s", EXPR), expected) + + expected = expect("function()", FUNC) + self.assertEqual(expect("function()", EXPR), expected) + + with self.assertRaises(exceptions.ProgrammingError): + expect(word="", token="ABC") + def test_expect_values(self): + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import a_args + from google.cloud.spanner_dbapi.parser import expect + from google.cloud.spanner_dbapi.parser import func + from google.cloud.spanner_dbapi.parser import pyfmt_str + from google.cloud.spanner_dbapi.parser import values + cases = [ ("VALUES ()", "", values([a_args([])])), ("VALUES", "", values([])), @@ -156,6 +256,10 @@ def test_expect_values(self): self.assertEqual(got_unconsumed, want_unconsumed) def test_expect_values_fail(self): + from google.cloud.spanner_dbapi.exceptions import ProgrammingError + from google.cloud.spanner_dbapi.parser import VALUES + from google.cloud.spanner_dbapi.parser import expect + cases = [ ("", "VALUES: `` does not start with VALUES"), ( @@ -172,3 +276,13 @@ def test_expect_values_fail(self): wantException, lambda: expect(text, VALUES), ) + + def test_as_values(self): + from google.cloud.spanner_dbapi.parser import as_values + + values = (1, 2) + with mock.patch( + "google.cloud.spanner_dbapi.parser.parse_values", + return_value=values, + ): + self.assertEqual(as_values(None), values[1]) diff --git a/tests/spanner_dbapi/test_types.py b/tests/unit/spanner_dbapi/test_types.py similarity index 82% rename from tests/spanner_dbapi/test_types.py rename to tests/unit/spanner_dbapi/test_types.py index 6c41041628..4246a43e45 100644 --- a/tests/spanner_dbapi/test_types.py +++ b/tests/unit/spanner_dbapi/test_types.py @@ -4,36 +4,48 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -import datetime -from time import timezone -from unittest import TestCase +import unittest -from google.cloud.spanner_dbapi import types +from time import timezone -class TypesTests(TestCase): +class TestTypes(unittest.TestCase): TICKS = 1572822862.9782631 + timezone # Sun 03 Nov 2019 23:14:22 UTC def test__date_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + actual = types._date_from_ticks(self.TICKS) expected = datetime.date(2019, 11, 3) self.assertEqual(actual, expected) def test__time_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + actual = types._time_from_ticks(self.TICKS) expected = datetime.time(23, 14, 22) self.assertEqual(actual, expected) def test__timestamp_from_ticks(self): + import datetime + + from google.cloud.spanner_dbapi import types + actual = types._timestamp_from_ticks(self.TICKS) expected = datetime.datetime(2019, 11, 3, 23, 14, 22) self.assertEqual(actual, expected) def test_type_equal(self): + from google.cloud.spanner_dbapi import types + self.assertEqual(types.BINARY, "TYPE_CODE_UNSPECIFIED") self.assertEqual(types.BINARY, "BYTES") self.assertEqual(types.BINARY, "ARRAY") diff --git a/tests/spanner_dbapi/test_utils.py b/tests/unit/spanner_dbapi/test_utils.py similarity index 67% rename from tests/spanner_dbapi/test_utils.py rename to tests/unit/spanner_dbapi/test_utils.py index 2ec10eefaf..90e1b7cf04 100644 --- a/tests/spanner_dbapi/test_utils.py +++ b/tests/unit/spanner_dbapi/test_utils.py @@ -4,13 +4,13 @@ # license that can be found in the LICENSE file or at # https://developers.google.com/open-source/licenses/bsd -from unittest import TestCase +import unittest -from google.cloud.spanner_dbapi.utils import PeekIterator - -class UtilsTests(TestCase): +class TestUtils(unittest.TestCase): def test_PeekIterator(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + cases = [ ("list", [1, 2, 3, 4, 6, 7], [1, 2, 3, 4, 6, 7]), ("iter_from_list", iter([1, 2, 3, 4, 6, 7]), [1, 2, 3, 4, 6, 7]), @@ -26,6 +26,8 @@ def test_PeekIterator(self): self.assertEqual(actual, expected) def test_peekIterator_list_rows_converted_to_tuples(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + # Cloud Spanner returns results in lists e.g. [result]. # PeekIterator is used by BaseCursor in its fetch* methods. # This test ensures that anything passed into PeekIterator @@ -47,7 +49,24 @@ def test_peekIterator_list_rows_converted_to_tuples(self): self.assertEqual(next(pit), ("Clark", "Kent")) def test_peekIterator_nonlist_rows_unconverted(self): + from google.cloud.spanner_dbapi.utils import PeekIterator + pi = PeekIterator(["a", "b", "c", "d", "e"]) got = list(pi) want = ["a", "b", "c", "d", "e"] self.assertEqual(got, want, "Values should be returned unchanged") + + def test_backtick_unicode(self): + from google.cloud.spanner_dbapi.utils import backtick_unicode + + cases = [ + ("SELECT (1) as foo WHERE 1=1", "SELECT (1) as foo WHERE 1=1"), + ("SELECT (1) as föö", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö`", "SELECT (1) as `föö`"), + ("SELECT (1) as `föö` `umläut", "SELECT (1) as `föö` `umläut"), + ("SELECT (1) as `föö", "SELECT (1) as `föö"), + ] + for sql, want in cases: + with self.subTest(sql=sql): + got = backtick_unicode(sql) + self.assertEqual(got, want)