1818
1919from google .protobuf import empty_pb2
2020from grpc_status .rpc_status import _Status
21+
22+ from google .cloud .spanner_v1 import (
23+ TransactionOptions ,
24+ ResultSetMetadata ,
25+ ExecuteSqlRequest ,
26+ ExecuteBatchDmlRequest ,
27+ )
2128from google .cloud .spanner_v1 .testing .mock_database_admin import DatabaseAdminServicer
2229import google .cloud .spanner_v1 .testing .spanner_database_admin_pb2_grpc as database_admin_grpc
2330import google .cloud .spanner_v1 .testing .spanner_pb2_grpc as spanner_grpc
@@ -51,23 +58,25 @@ def pop_error(self, context):
5158 context .abort_with_status (error )
5259
5360 def get_result_as_partial_result_sets (
54- self , sql : str
61+ self , sql : str , started_transaction : transaction . Transaction
5562 ) -> [result_set .PartialResultSet ]:
5663 result : result_set .ResultSet = self .get_result (sql )
5764 partials = []
5865 first = True
5966 if len (result .rows ) == 0 :
6067 partial = result_set .PartialResultSet ()
61- partial .metadata = result .metadata
68+ partial .metadata = ResultSetMetadata ( result .metadata )
6269 partials .append (partial )
6370 else :
6471 for row in result .rows :
6572 partial = result_set .PartialResultSet ()
6673 if first :
67- partial .metadata = result .metadata
74+ partial .metadata = ResultSetMetadata ( result .metadata )
6875 partial .values .extend (row )
6976 partials .append (partial )
7077 partials [len (partials ) - 1 ].stats = result .stats
78+ if started_transaction :
79+ partials [0 ].metadata .transaction = started_transaction
7180 return partials
7281
7382
@@ -129,22 +138,29 @@ def DeleteSession(self, request, context):
129138
130139 def ExecuteSql (self , request , context ):
131140 self ._requests .append (request )
132- return result_set .ResultSet ()
141+ self .mock_spanner .pop_error (context )
142+ started_transaction = self .__maybe_create_transaction (request )
143+ result : result_set .ResultSet = self .mock_spanner .get_result (request .sql )
144+ if started_transaction :
145+ result .metadata = ResultSetMetadata (result .metadata )
146+ result .metadata .transaction = started_transaction
147+ return result
133148
134149 def ExecuteStreamingSql (self , request , context ):
135150 self ._requests .append (request )
136- partials = self .mock_spanner .get_result_as_partial_result_sets (request .sql )
151+ self .mock_spanner .pop_error (context )
152+ started_transaction = self .__maybe_create_transaction (request )
153+ partials = self .mock_spanner .get_result_as_partial_result_sets (
154+ request .sql , started_transaction
155+ )
137156 for result in partials :
138157 yield result
139158
140159 def ExecuteBatchDml (self , request , context ):
141160 self ._requests .append (request )
161+ self .mock_spanner .pop_error (context )
142162 response = spanner .ExecuteBatchDmlResponse ()
143- started_transaction = None
144- if not request .transaction .begin == transaction .TransactionOptions ():
145- started_transaction = self .__create_transaction (
146- request .session , request .transaction .begin
147- )
163+ started_transaction = self .__maybe_create_transaction (request )
148164 first = True
149165 for statement in request .statements :
150166 result = self .mock_spanner .get_result (statement .sql )
@@ -170,6 +186,16 @@ def BeginTransaction(self, request, context):
170186 self ._requests .append (request )
171187 return self .__create_transaction (request .session , request .options )
172188
189+ def __maybe_create_transaction (
190+ self , request : ExecuteSqlRequest | ExecuteBatchDmlRequest
191+ ):
192+ started_transaction = None
193+ if not request .transaction .begin == TransactionOptions ():
194+ started_transaction = self .__create_transaction (
195+ request .session , request .transaction .begin
196+ )
197+ return started_transaction
198+
173199 def __create_transaction (
174200 self , session : str , options : transaction .TransactionOptions
175201 ) -> transaction .Transaction :
0 commit comments