diff --git a/google/cloud/spanner_dbapi/_helpers.py b/google/cloud/spanner_dbapi/_helpers.py
index 02901ffc3a..c7f9e59afb 100644
--- a/google/cloud/spanner_dbapi/_helpers.py
+++ b/google/cloud/spanner_dbapi/_helpers.py
@@ -12,8 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from google.cloud.spanner_dbapi.parse_utils import get_param_types
-from google.cloud.spanner_dbapi.parse_utils import sql_pyformat_args_to_spanner
 from google.cloud.spanner_v1 import param_types
 
 
@@ -47,24 +45,6 @@
 }
 
 
-def _execute_insert_heterogenous(
-    transaction,
-    sql_params_list,
-    request_options=None,
-):
-    for sql, params in sql_params_list:
-        sql, params = sql_pyformat_args_to_spanner(sql, params)
-        transaction.execute_update(
-            sql, params, get_param_types(params), request_options=request_options
-        )
-
-
-def handle_insert(connection, sql, params):
-    return connection.database.run_in_transaction(
-        _execute_insert_heterogenous, ((sql, params),), connection.request_options
-    )
-
-
 class ColumnInfo:
     """Row column description object."""
 
diff --git a/google/cloud/spanner_dbapi/connection.py b/google/cloud/spanner_dbapi/connection.py
index 75263400f8..a1d46d3efe 100644
--- a/google/cloud/spanner_dbapi/connection.py
+++ b/google/cloud/spanner_dbapi/connection.py
@@ -24,7 +24,6 @@
 from google.cloud.spanner_v1.session import _get_retry_delay
 from google.cloud.spanner_v1.snapshot import Snapshot
 
-from google.cloud.spanner_dbapi._helpers import _execute_insert_heterogenous
 from google.cloud.spanner_dbapi.checksum import _compare_checksums
 from google.cloud.spanner_dbapi.checksum import ResultsChecksum
 from google.cloud.spanner_dbapi.cursor import Cursor
@@ -450,15 +449,6 @@ def run_statement(self, statement, retried=False):
         if not retried:
             self._statements.append(statement)
 
-        if statement.is_insert:
-            _execute_insert_heterogenous(
-                transaction, ((statement.sql, statement.params),), self.request_options
-            )
-            return (
-                iter(()),
-                ResultsChecksum() if retried else statement.checksum,
-            )
-
         return (
             transaction.execute_sql(
                 statement.sql,
diff --git a/google/cloud/spanner_dbapi/cursor.py b/google/cloud/spanner_dbapi/cursor.py
index f8220d2c68..ac3888f35d 100644
--- a/google/cloud/spanner_dbapi/cursor.py
+++ b/google/cloud/spanner_dbapi/cursor.py
@@ -47,7 +47,7 @@
 _UNSET_COUNT = -1
 
 ColumnDetails = namedtuple("column_details", ["null_ok", "spanner_type"])
-Statement = namedtuple("Statement", "sql, params, param_types, checksum, is_insert")
+Statement = namedtuple("Statement", "sql, params, param_types, checksum")
 
 
 def check_not_closed(function):
@@ -137,14 +137,21 @@ def description(self):
 
     @property
     def rowcount(self):
-        """The number of rows updated by the last UPDATE, DELETE request's `execute()` call.
+        """The number of rows updated by the last INSERT, UPDATE, DELETE request's `execute()` call.
 
         For SELECT requests the rowcount returns -1.
 
         :rtype: int
-        :returns: The number of rows updated by the last UPDATE, DELETE request's .execute*() call.
+        :returns: The number of rows updated by the last INSERT, UPDATE, DELETE request's .execute*() call.
         """
-        return self._row_count
+        if self._row_count != _UNSET_COUNT or self._result_set is None:
+            return self._row_count
+
+        stats = getattr(self._result_set, "stats", None)
+        if stats is not None and "row_count_exact" in stats:
+            return stats.row_count_exact
+
+        return _UNSET_COUNT
 
     @check_not_closed
     def callproc(self, procname, args=None):
@@ -171,17 +178,11 @@ def close(self):
         self._is_closed = True
 
     def _do_execute_update(self, transaction, sql, params):
-        result = transaction.execute_update(
-            sql,
-            params=params,
-            param_types=get_param_types(params),
-            request_options=self.connection.request_options,
+        self._result_set = transaction.execute_sql(
+            sql, params=params, param_types=get_param_types(params)
         )
-        self._itr = None
-        if type(result) == int:
-            self._row_count = result
-
-        return result
+        self._itr = PeekIterator(self._result_set)
+        self._row_count = _UNSET_COUNT
 
     def _do_batch_update(self, transaction, statements, many_result_set):
         status, res = transaction.batch_update(statements)
@@ -227,7 +228,9 @@ def execute(self, sql, args=None):
         :type args: list
         :param args: Additional parameters to supplement the SQL query.
         """
+        self._itr = None
         self._result_set = None
+        self._row_count = _UNSET_COUNT
 
         try:
             if self.connection.read_only:
@@ -249,18 +252,14 @@
             if class_ == parse_utils.STMT_UPDATING:
                 sql = parse_utils.ensure_where_clause(sql)
 
-            if class_ != parse_utils.STMT_INSERT:
-                sql, args = sql_pyformat_args_to_spanner(sql, args or None)
+            sql, args = sql_pyformat_args_to_spanner(sql, args or None)
 
             if not self.connection.autocommit:
                 statement = Statement(
                     sql,
                     args,
-                    get_param_types(args or None)
-                    if class_ != parse_utils.STMT_INSERT
-                    else {},
+                    get_param_types(args or None),
                     ResultsChecksum(),
-                    class_ == parse_utils.STMT_INSERT,
                 )
 
                 (
@@ -277,8 +276,6 @@
 
             if class_ == parse_utils.STMT_NON_UPDATING:
                 self._handle_DQL(sql, args or None)
-            elif class_ == parse_utils.STMT_INSERT:
-                _helpers.handle_insert(self.connection, sql, args or None)
             else:
                 self.connection.database.run_in_transaction(
                     self._do_execute_update,
@@ -304,6 +301,10 @@ def executemany(self, operation, seq_of_params):
         :param seq_of_params: Sequence of additional parameters to run the
                               query with.
         """
+        self._itr = None
+        self._result_set = None
+        self._row_count = _UNSET_COUNT
+
        class_ = parse_utils.classify_stmt(operation)
        if class_ == parse_utils.STMT_DDL:
            raise ProgrammingError(
@@ -327,6 +328,7 @@
                 )
         else:
             retried = False
+            total_row_count = 0
             while True:
                 try:
                     transaction = self.connection.transaction_checkout()
@@ -341,12 +343,14 @@
                     many_result_set.add_iter(res)
                     res_checksum.consume_result(res)
                     res_checksum.consume_result(status.code)
+                    total_row_count += sum([max(val, 0) for val in res])
 
                     if status.code == ABORTED:
                         self.connection._transaction = None
                         raise Aborted(status.message)
                     elif status.code != OK:
                         raise OperationalError(status.message)
+                    self._row_count = total_row_count
                     break
                 except Aborted:
                     self.connection.retry_transaction()
diff --git a/tests/system/_helpers.py b/tests/system/_helpers.py
index fba1f1a5a5..60926b216e 100644
--- a/tests/system/_helpers.py
+++ b/tests/system/_helpers.py
@@ -30,6 +30,9 @@
 INSTANCE_ID_DEFAULT = "google-cloud-python-systest"
 INSTANCE_ID = os.environ.get(INSTANCE_ID_ENVVAR, INSTANCE_ID_DEFAULT)
 
+API_ENDPOINT_ENVVAR = "GOOGLE_CLOUD_TESTS_SPANNER_HOST"
+API_ENDPOINT = os.getenv(API_ENDPOINT_ENVVAR)
+
 SKIP_BACKUP_TESTS_ENVVAR = "SKIP_BACKUP_TESTS"
 SKIP_BACKUP_TESTS = os.getenv(SKIP_BACKUP_TESTS_ENVVAR) is not None
 
diff --git a/tests/system/conftest.py b/tests/system/conftest.py
index 3d6706b582..fdeab14c8f 100644
--- a/tests/system/conftest.py
+++ b/tests/system/conftest.py
@@ -85,7 +85,10 @@ def spanner_client():
             credentials=credentials,
         )
     else:
-        return spanner_v1.Client()  # use google.auth.default credentials
+        client_options = {"api_endpoint": _helpers.API_ENDPOINT}
+        return spanner_v1.Client(
+            client_options=client_options
+        )  # use google.auth.default credentials
 
 
 @pytest.fixture(scope="session")
diff --git a/tests/system/test_dbapi.py b/tests/system/test_dbapi.py
index 7327ef1d0d..0b92d7a15d 100644
--- a/tests/system/test_dbapi.py
+++ b/tests/system/test_dbapi.py
@@ -501,3 +501,126 @@ def test_staleness(shared_instance, dbapi_database):
     assert len(cursor.fetchall()) == 1
 
     conn.close()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+def test_rowcount(shared_instance, dbapi_database, autocommit):
+    conn = Connection(shared_instance, dbapi_database)
+    conn.autocommit = autocommit
+    cur = conn.cursor()
+
+    cur.execute(
+        """
+    CREATE TABLE Singers (
+        SingerId INT64 NOT NULL,
+        Name STRING(1024),
+    ) PRIMARY KEY (SingerId)
+    """
+    )
+    conn.commit()
+
+    # executemany sets rowcount to the total modified rows
+    rows = [(i, f"Singer {i}") for i in range(100)]
+    cur.executemany("INSERT INTO Singers (SingerId, Name) VALUES (%s, %s)", rows[:98])
+    assert cur.rowcount == 98
+
+    # execute with INSERT
+    cur.execute(
+        "INSERT INTO Singers (SingerId, Name) VALUES (%s, %s), (%s, %s)",
+        [x for row in rows[98:] for x in row],
+    )
+    assert cur.rowcount == 2
+
+    # execute with UPDATE
+    cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
+    assert cur.rowcount == 25
+
+    # execute with SELECT
+    cur.execute("SELECT Name FROM Singers WHERE SingerId < 75")
+    assert len(cur.fetchall()) == 75
+    # rowcount is not available for SELECT
+    assert cur.rowcount == -1
+
+    # execute with DELETE
+    cur.execute("DELETE FROM Singers")
+    assert cur.rowcount == 100
+
+    # execute with UPDATE matching 0 rows
+    cur.execute("UPDATE Singers SET Name = 'Cher' WHERE SingerId < 25")
+    assert cur.rowcount == 0
+
+    conn.commit()
+    cur.execute("DROP TABLE Singers")
+    conn.commit()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def test_dml_returning_insert(shared_instance, dbapi_database, autocommit):
+    conn = Connection(shared_instance, dbapi_database)
+    conn.autocommit = autocommit
+    cur = conn.cursor()
+    cur.execute(
+        """
+INSERT INTO contacts (contact_id, first_name, last_name, email)
+VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
+THEN RETURN contact_id, first_name
+    """
+    )
+    assert cur.fetchone() == (1, "first-name")
+    assert cur.rowcount == 1
+    conn.commit()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def test_dml_returning_update(shared_instance, dbapi_database, autocommit):
+    conn = Connection(shared_instance, dbapi_database)
+    conn.autocommit = autocommit
+    cur = conn.cursor()
+    cur.execute(
+        """
+INSERT INTO contacts (contact_id, first_name, last_name, email)
+VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
+    """
+    )
+    assert cur.rowcount == 1
+    cur.execute(
+        """
+UPDATE contacts SET first_name = 'new-name' WHERE contact_id = 1
+THEN RETURN contact_id, first_name
+    """
+    )
+    assert cur.fetchone() == (1, "new-name")
+    assert cur.rowcount == 1
+    conn.commit()
+
+
+@pytest.mark.parametrize("autocommit", [False, True])
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def test_dml_returning_delete(shared_instance, dbapi_database, autocommit):
+    conn = Connection(shared_instance, dbapi_database)
+    conn.autocommit = autocommit
+    cur = conn.cursor()
+    cur.execute(
+        """
+INSERT INTO contacts (contact_id, first_name, last_name, email)
+VALUES (1, 'first-name', 'last-name', 'test.email@example.com')
+    """
+    )
+    assert cur.rowcount == 1
+    cur.execute(
+        """
+DELETE FROM contacts WHERE contact_id = 1
+THEN RETURN contact_id, first_name
+    """
+    )
+    assert cur.fetchone() == (1, "first-name")
+    assert cur.rowcount == 1
+    conn.commit()
diff --git a/tests/system/test_session_api.py b/tests/system/test_session_api.py
index 8e7b65d95e..aedcbcaa55 100644
--- a/tests/system/test_session_api.py
+++ b/tests/system/test_session_api.py
@@ -635,12 +635,30 @@ def test_transaction_read_and_insert_or_update_then_commit(
 
 
 def _generate_insert_statements():
+    for row in _sample_data.ROW_DATA:
+        yield _generate_insert_statement(row)
+
+
+def _generate_insert_statement(row):
     table = _sample_data.TABLE
     column_list = ", ".join(_sample_data.COLUMNS)
+    row_data = "{}, '{}', '{}', '{}'".format(*row)
+    return f"INSERT INTO {table} ({column_list}) VALUES ({row_data})"
 
-    for row in _sample_data.ROW_DATA:
-        row_data = "{}, '{}', '{}', '{}'".format(*row)
-        yield f"INSERT INTO {table} ({column_list}) VALUES ({row_data})"
+
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def _generate_insert_returning_statement(row, database_dialect):
+    table = _sample_data.TABLE
+    column_list = ", ".join(_sample_data.COLUMNS)
+    row_data = "{}, '{}', '{}', '{}'".format(*row)
+    returning = (
+        f"RETURNING {column_list}"
+        if database_dialect == DatabaseDialect.POSTGRESQL
+        else f"THEN RETURN {column_list}"
+    )
+    return f"INSERT INTO {table} ({column_list}) VALUES ({row_data}) {returning}"
 
 
 @_helpers.retry_mabye_conflict
@@ -742,6 +760,98 @@ def test_transaction_execute_update_then_insert_commit(
 # [END spanner_test_dml_with_mutation]
 
 
+@_helpers.retry_mabye_conflict
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def test_transaction_execute_sql_dml_returning(
+    sessions_database, sessions_to_delete, database_dialect
+):
+    sd = _sample_data
+
+    session = sessions_database.session()
+    session.create()
+    sessions_to_delete.append(session)
+
+    with session.batch() as batch:
+        batch.delete(sd.TABLE, sd.ALL)
+
+    with session.transaction() as transaction:
+        for row in sd.ROW_DATA:
+            insert_statement = _generate_insert_returning_statement(
+                row, database_dialect
+            )
+            results = transaction.execute_sql(insert_statement)
+            returned = results.one()
+            assert list(row) == list(returned)
+            row_count = results.stats.row_count_exact
+            assert row_count == 1
+
+    rows = list(session.read(sd.TABLE, sd.COLUMNS, sd.ALL))
+    sd._check_rows_data(rows)
+
+
+@_helpers.retry_mabye_conflict
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def test_transaction_execute_update_dml_returning(
+    sessions_database, sessions_to_delete, database_dialect
+):
+    sd = _sample_data
+
+    session = sessions_database.session()
+    session.create()
+    sessions_to_delete.append(session)
+
+    with session.batch() as batch:
+        batch.delete(sd.TABLE, sd.ALL)
+
+    with session.transaction() as transaction:
+        for row in sd.ROW_DATA:
+            insert_statement = _generate_insert_returning_statement(
+                row, database_dialect
+            )
+            row_count = transaction.execute_update(insert_statement)
+            assert row_count == 1
+
+    rows = list(session.read(sd.TABLE, sd.COLUMNS, sd.ALL))
+    sd._check_rows_data(rows)
+
+
+@_helpers.retry_mabye_conflict
+@pytest.mark.skipif(
+    _helpers.USE_EMULATOR, reason="Emulator does not support DML Returning."
+)
+def test_transaction_batch_update_dml_returning(
+    sessions_database, sessions_to_delete, database_dialect
+):
+    sd = _sample_data
+
+    session = sessions_database.session()
+    session.create()
+    sessions_to_delete.append(session)
+
+    with session.batch() as batch:
+        batch.delete(sd.TABLE, sd.ALL)
+
+    with session.transaction() as transaction:
+        insert_statements = [
+            _generate_insert_returning_statement(row, database_dialect)
+            for row in sd.ROW_DATA
+        ]
+
+        status, row_counts = transaction.batch_update(insert_statements)
+        _check_batch_status(status.code)
+        assert len(row_counts) == 3
+
+        for row_count in row_counts:
+            assert row_count == 1
+
+    rows = list(session.read(sd.TABLE, sd.COLUMNS, sd.ALL))
+    sd._check_rows_data(rows)
+
+
 def test_transaction_batch_update_success(
     sessions_database, sessions_to_delete, database_dialect
 ):
diff --git a/tests/unit/spanner_dbapi/test__helpers.py b/tests/unit/spanner_dbapi/test__helpers.py
index c770ff6e4b..01302707b5 100644
--- a/tests/unit/spanner_dbapi/test__helpers.py
+++ b/tests/unit/spanner_dbapi/test__helpers.py
@@ -14,75 +14,9 @@
 
 """Cloud Spanner DB-API Connection class unit tests."""
 
-import mock
 import unittest
 
 
-class TestHelpers(unittest.TestCase):
-    def test__execute_insert_heterogenous(self):
-        from google.cloud.spanner_dbapi import _helpers
-
-        sql = "sql"
-        params = (sql, None)
-        with mock.patch(
-            "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner",
-            return_value=params,
-        ) as mock_pyformat:
-            with mock.patch(
-                "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None
-            ) as mock_param_types:
-                transaction = mock.MagicMock()
-                transaction.execute_update = mock_update = mock.MagicMock()
-                _helpers._execute_insert_heterogenous(transaction, (params,))
-
-                mock_pyformat.assert_called_once_with(params[0], params[1])
-                mock_param_types.assert_called_once_with(None)
-                mock_update.assert_called_once_with(
-                    sql, None, None, request_options=None
-                )
-
-    def test__execute_insert_heterogenous_error(self):
-        from google.cloud.spanner_dbapi import _helpers
-        from google.api_core.exceptions import Unknown
-
-        sql = "sql"
-        params = (sql, None)
-        with mock.patch(
-            "google.cloud.spanner_dbapi._helpers.sql_pyformat_args_to_spanner",
-            return_value=params,
-        ) as mock_pyformat:
-            with mock.patch(
-                "google.cloud.spanner_dbapi._helpers.get_param_types", return_value=None
-            ) as mock_param_types:
-                transaction = mock.MagicMock()
-                transaction.execute_update = mock_update = mock.MagicMock(
-                    side_effect=Unknown("Unknown")
-                )
-
-                with self.assertRaises(Unknown):
-                    _helpers._execute_insert_heterogenous(transaction, (params,))
-
-                mock_pyformat.assert_called_once_with(params[0], params[1])
-                mock_param_types.assert_called_once_with(None)
-                mock_update.assert_called_once_with(
-                    sql, None, None, request_options=None
-                )
-
-    def test_handle_insert(self):
-        from google.cloud.spanner_dbapi import _helpers
-
-        connection = mock.MagicMock()
-        connection.database.run_in_transaction = mock_run_in = mock.MagicMock()
-        sql = "sql"
-        mock_run_in.return_value = 0
-        result = _helpers.handle_insert(connection, sql, None)
-        self.assertEqual(result, 0)
-
-        mock_run_in.return_value = 1
-        result = _helpers.handle_insert(connection, sql, None)
-        self.assertEqual(result, 1)
-
-
 class TestColumnInfo(unittest.TestCase):
     def test_ctor(self):
         from google.cloud.spanner_dbapi.cursor import ColumnInfo
diff --git a/tests/unit/spanner_dbapi/test_connection.py b/tests/unit/spanner_dbapi/test_connection.py
index 23fc098afc..090def3519 100644
--- a/tests/unit/spanner_dbapi/test_connection.py
+++ b/tests/unit/spanner_dbapi/test_connection.py
@@ -364,7 +364,7 @@ def test_run_statement_wo_retried(self):
         connection = self._make_connection()
         connection.transaction_checkout = mock.Mock()
 
-        statement = Statement(sql, params, param_types, ResultsChecksum(), False)
+        statement = Statement(sql, params, param_types, ResultsChecksum())
         connection.run_statement(statement)
 
         self.assertEqual(connection._statements[0].sql, sql)
@@ -383,7 +383,7 @@ def test_run_statement_w_retried(self):
         connection = self._make_connection()
         connection.transaction_checkout = mock.Mock()
 
-        statement = Statement(sql, params, param_types, ResultsChecksum(), False)
+        statement = Statement(sql, params, param_types, ResultsChecksum())
         connection.run_statement(statement, retried=True)
 
         self.assertEqual(len(connection._statements), 0)
@@ -403,7 +403,7 @@ def test_run_statement_w_heterogenous_insert_statements(self):
         transaction = mock.MagicMock()
         connection.transaction_checkout = mock.Mock(return_value=transaction)
         transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1))
-        statement = Statement(sql, params, param_types, ResultsChecksum(), True)
+        statement = Statement(sql, params, param_types, ResultsChecksum())
 
         connection.run_statement(statement, retried=True)
 
@@ -424,7 +424,7 @@ def test_run_statement_w_homogeneous_insert_statements(self):
         transaction = mock.MagicMock()
         connection.transaction_checkout = mock.Mock(return_value=transaction)
         transaction.batch_update = mock.Mock(return_value=(Status(code=OK), 1))
-        statement = Statement(sql, params, param_types, ResultsChecksum(), True)
+        statement = Statement(sql, params, param_types, ResultsChecksum())
 
         connection.run_statement(statement, retried=True)
 
@@ -476,7 +476,7 @@ def test_retry_transaction_w_checksum_match(self):
         run_mock = connection.run_statement = mock.Mock()
         run_mock.return_value = ([row], retried_checkum)
 
-        statement = Statement("SELECT 1", [], {}, checksum, False)
+        statement = Statement("SELECT 1", [], {}, checksum)
         connection._statements.append(statement)
 
         with mock.patch(
@@ -506,7 +506,7 @@ def test_retry_transaction_w_checksum_mismatch(self):
         run_mock = connection.run_statement = mock.Mock()
         run_mock.return_value = ([retried_row], retried_checkum)
 
-        statement = Statement("SELECT 1", [], {}, checksum, False)
+        statement = Statement("SELECT 1", [], {}, checksum)
         connection._statements.append(statement)
 
         with self.assertRaises(RetryAborted):
@@ -528,7 +528,7 @@ def test_commit_retry_aborted_statements(self, mock_client):
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
         mock_transaction = mock.Mock(rolled_back=False, committed=False)
         connection._transaction = mock_transaction
@@ -573,7 +573,7 @@ def test_retry_aborted_retry(self, mock_client):
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
         metadata_mock = mock.Mock()
         metadata_mock.trailing_metadata.return_value = {}
@@ -605,7 +605,7 @@ def test_retry_transaction_raise_max_internal_retries(self):
         checksum = ResultsChecksum()
         checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, checksum, False)
+        statement = Statement("SELECT 1", [], {}, checksum)
         connection._statements.append(statement)
 
         with self.assertRaises(Exception):
@@ -632,7 +632,7 @@ def test_retry_aborted_retry_without_delay(self, mock_client):
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
         metadata_mock = mock.Mock()
         metadata_mock.trailing_metadata.return_value = {}
@@ -664,8 +664,8 @@ def test_retry_transaction_w_multiple_statement(self):
         checksum.consume_result(row)
         retried_checkum = ResultsChecksum()
 
-        statement = Statement("SELECT 1", [], {}, checksum, False)
-        statement1 = Statement("SELECT 2", [], {}, checksum, False)
+        statement = Statement("SELECT 1", [], {}, checksum)
+        statement1 = Statement("SELECT 2", [], {}, checksum)
         connection._statements.append(statement)
         connection._statements.append(statement1)
         run_mock = connection.run_statement = mock.Mock()
@@ -692,7 +692,7 @@ def test_retry_transaction_w_empty_response(self):
         checksum.count = 1
         retried_checkum = ResultsChecksum()
 
-        statement = Statement("SELECT 1", [], {}, checksum, False)
+        statement = Statement("SELECT 1", [], {}, checksum)
         connection._statements.append(statement)
         run_mock = connection.run_statement = mock.Mock()
         run_mock.return_value = ([row], retried_checkum)
@@ -901,9 +901,7 @@ def test_request_priority(self):
 
         req_opts = RequestOptions(priority=priority)
 
-        connection.run_statement(
-            Statement(sql, params, param_types, ResultsChecksum(), False)
-        )
+        connection.run_statement(Statement(sql, params, param_types, ResultsChecksum()))
 
         connection._transaction.execute_sql.assert_called_with(
             sql, params, param_types=param_types, request_options=req_opts
@@ -911,9 +909,7 @@
         assert connection.request_priority is None
 
         # check that priority is applied for only one request
-        connection.run_statement(
-            Statement(sql, params, param_types, ResultsChecksum(), False)
-        )
+        connection.run_statement(Statement(sql, params, param_types, ResultsChecksum()))
 
         connection._transaction.execute_sql.assert_called_with(
             sql, params, param_types=param_types, request_options=None
diff --git a/tests/unit/spanner_dbapi/test_cursor.py b/tests/unit/spanner_dbapi/test_cursor.py
index 75089362af..79ed898355 100644
--- a/tests/unit/spanner_dbapi/test_cursor.py
+++ b/tests/unit/spanner_dbapi/test_cursor.py
@@ -97,28 +97,23 @@ def test_close(self, mock_client):
             cursor.execute("SELECT * FROM database")
 
     def test_do_execute_update(self):
-        from google.cloud.spanner_dbapi.cursor import _UNSET_COUNT
+        from google.cloud.spanner_v1 import ResultSetStats
 
         connection = self._make_connection(self.INSTANCE, self.DATABASE)
         cursor = self._make_one(connection)
         transaction = mock.MagicMock()
+        result_set = mock.MagicMock()
+        result_set.stats = ResultSetStats(row_count_exact=1234)
+
+        transaction.execute_sql.return_value = result_set
+        cursor._do_execute_update(
+            transaction=transaction,
+            sql="SELECT * WHERE true",
+            params={},
+        )
 
-        def run_helper(ret_value):
-            transaction.execute_update.return_value = ret_value
-            res = cursor._do_execute_update(
-                transaction=transaction,
-                sql="SELECT * WHERE true",
-                params={},
-            )
-            return res
-
-        expected = "good"
-        self.assertEqual(run_helper(expected), expected)
-        self.assertEqual(cursor._row_count, _UNSET_COUNT)
-
-        expected = 1234
-        self.assertEqual(run_helper(expected), expected)
-        self.assertEqual(cursor._row_count, expected)
+        self.assertEqual(cursor._result_set, result_set)
+        self.assertEqual(cursor.rowcount, 1234)
 
     def test_do_batch_update(self):
         from google.cloud.spanner_dbapi import connect
@@ -193,7 +188,7 @@ def test_execute_insert_statement_autocommit_off(self):
         cursor._checksum = ResultsChecksum()
         with mock.patch(
             "google.cloud.spanner_dbapi.parse_utils.classify_stmt",
-            return_value=parse_utils.STMT_INSERT,
+            return_value=parse_utils.STMT_UPDATING,
         ):
             with mock.patch(
                 "google.cloud.spanner_dbapi.connection.Connection.run_statement",
@@ -213,7 +208,7 @@ def test_execute_statement(self):
 
         with mock.patch(
             "google.cloud.spanner_dbapi.parse_utils.classify_stmt",
-            side_effect=[parse_utils.STMT_DDL, parse_utils.STMT_INSERT],
+            side_effect=[parse_utils.STMT_DDL, parse_utils.STMT_UPDATING],
         ) as mock_classify_stmt:
             sql = "sql"
             with self.assertRaises(ValueError):
@@ -245,18 +240,6 @@
             cursor.execute(sql=sql)
             mock_handle_ddl.assert_called_once_with(sql, None)
 
-        with mock.patch(
-            "google.cloud.spanner_dbapi.parse_utils.classify_stmt",
-            return_value=parse_utils.STMT_INSERT,
-        ):
-            with mock.patch(
-                "google.cloud.spanner_dbapi._helpers.handle_insert",
-                return_value=parse_utils.STMT_INSERT,
-            ) as mock_handle_insert:
-                sql = "sql"
-                cursor.execute(sql=sql)
-                mock_handle_insert.assert_called_once_with(connection, sql, None)
-
         with mock.patch(
             "google.cloud.spanner_dbapi.parse_utils.classify_stmt",
             return_value="other_statement",
@@ -923,7 +906,7 @@ def test_fetchone_retry_aborted_statements(self, mock_client):
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
 
         with mock.patch(
@@ -957,7 +940,7 @@ def test_fetchone_retry_aborted_statements_checksums_mismatch(self, mock_client)
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
 
         with mock.patch(
@@ -1013,7 +996,7 @@ def test_fetchall_retry_aborted_statements(self, mock_client):
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
 
         with mock.patch(
@@ -1046,7 +1029,7 @@ def test_fetchall_retry_aborted_statements_checksums_mismatch(self, mock_client)
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
 
         with mock.patch(
@@ -1102,7 +1085,7 @@ def test_fetchmany_retry_aborted_statements(self, mock_client):
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
 
         with mock.patch(
@@ -1136,7 +1119,7 @@ def test_fetchmany_retry_aborted_statements_checksums_mismatch(self, mock_client)
         cursor._checksum = ResultsChecksum()
         cursor._checksum.consume_result(row)
 
-        statement = Statement("SELECT 1", [], {}, cursor._checksum, False)
+        statement = Statement("SELECT 1", [], {}, cursor._checksum)
         connection._statements.append(statement)
 
         with mock.patch(