@@ -131,6 +131,161 @@ def test_propagation(enable_extended_tracing):
131131 test_propagation (False )
132132
133133
134+ @pytest .mark .skipif (
135+ not _helpers .USE_EMULATOR ,
136+ reason = "Emulator needed to run this tests" ,
137+ )
138+ @pytest .mark .skipif (
139+ not HAS_OTEL_INSTALLED ,
140+ reason = "Tracing requires OpenTelemetry" ,
141+ )
142+ def test_transaction_abort_then_retry_spans ():
143+ from google .auth .credentials import AnonymousCredentials
144+ from google .api_core .exceptions import Aborted
145+ from google .rpc import code_pb2
146+ from opentelemetry .sdk .trace .export import SimpleSpanProcessor
147+ from opentelemetry .sdk .trace .export .in_memory_span_exporter import (
148+ InMemorySpanExporter ,
149+ )
150+ from opentelemetry .trace .status import StatusCode
151+ from opentelemetry .sdk .trace import TracerProvider
152+ from opentelemetry .sdk .trace .sampling import ALWAYS_ON
153+ from opentelemetry import trace
154+
155+ PROJECT = _helpers .EMULATOR_PROJECT
156+ CONFIGURATION_NAME = "config-name"
157+ INSTANCE_ID = _helpers .INSTANCE_ID
158+ DISPLAY_NAME = "display-name"
159+ DATABASE_ID = _helpers .unique_id ("temp_db" )
160+ NODE_COUNT = 5
161+ LABELS = {"test" : "true" }
162+
163+ counters = dict (aborted = 0 )
164+ already_aborted = False
165+
166+ def select_in_txn (txn ):
167+ from google .rpc import error_details_pb2
168+
169+ results = txn .execute_sql ("SELECT 1" )
170+ for row in results :
171+ _ = row
172+
173+ if counters ["aborted" ] == 0 :
174+ counters ["aborted" ] = 1
175+ raise Aborted (
176+ "Thrown from ClientInterceptor for testing" ,
177+ errors = [_helpers .FauxCall (code_pb2 .ABORTED )],
178+ )
179+
180+ tracer_provider = TracerProvider (sampler = ALWAYS_ON )
181+ trace_exporter = InMemorySpanExporter ()
182+ tracer_provider .add_span_processor (SimpleSpanProcessor (trace_exporter ))
183+ observability_options = dict (
184+ tracer_provider = tracer_provider ,
185+ enable_extended_tracing = True ,
186+ )
187+
188+ client = Client (
189+ project = PROJECT ,
190+ observability_options = observability_options ,
191+ credentials = AnonymousCredentials (),
192+ )
193+
194+ instance = client .instance (
195+ INSTANCE_ID ,
196+ CONFIGURATION_NAME ,
197+ display_name = DISPLAY_NAME ,
198+ node_count = NODE_COUNT ,
199+ labels = LABELS ,
200+ )
201+
202+ try :
203+ instance .create ()
204+ except Exception :
205+ pass
206+
207+ db = instance .database (DATABASE_ID )
208+ try :
209+ db .create ()
210+ except Exception :
211+ pass
212+
213+ db .run_in_transaction (select_in_txn )
214+
215+ span_list = trace_exporter .get_finished_spans ()
216+ got_span_names = [span .name for span in span_list ]
217+ want_span_names = [
218+ "CloudSpanner.CreateSession" ,
219+ "CloudSpanner.Transaction.execute_streaming_sql" ,
220+ "CloudSpanner.Transaction.execute_streaming_sql" ,
221+ "CloudSpanner.Transaction.commit" ,
222+ "CloudSpanner.Session.run_in_transaction" ,
223+ "CloudSpanner.Database.run_in_transaction" ,
224+ ]
225+
226+ assert got_span_names == want_span_names
227+
228+ got_events = []
229+ got_statuses = []
230+
231+ # Some event attributes are noisy/highly ephemeral
232+ # and can't be directly compared against.
233+ imprecise_event_attributes = ["exception.stacktrace" , "delay_seconds" ]
234+ for span in span_list :
235+ got_statuses .append (
236+ (span .name , span .status .status_code , span .status .description )
237+ )
238+ for event in span .events :
239+ evt_attributes = event .attributes .copy ()
240+ for attr_name in imprecise_event_attributes :
241+ if attr_name in evt_attributes :
242+ evt_attributes [attr_name ] = "EPHEMERAL"
243+
244+ got_events .append ((event .name , evt_attributes ))
245+
246+ # Check for the series of events
247+ want_events = [
248+ ("Starting Commit" , {}),
249+ ("Commit Done" , {}),
250+ ("Using Transaction" , {"attempt" : 1 }),
251+ (
252+ "exception" ,
253+ {
254+ "exception.type" : "google.api_core.exceptions.Aborted" ,
255+ "exception.message" : "409 Thrown from ClientInterceptor for testing" ,
256+ "exception.stacktrace" : "EPHEMERAL" ,
257+ "exception.escaped" : "False" ,
258+ },
259+ ),
260+ (
261+ "Transaction was aborted in user operation, retrying" ,
262+ {"delay_seconds" : "EPHEMERAL" , "attempt" : 1 },
263+ ),
264+ ("Using Transaction" , {"attempt" : 2 }),
265+ ("Acquiring session" , {"kind" : "BurstyPool" }),
266+ ("Waiting for a session to become available" , {"kind" : "BurstyPool" }),
267+ ("No sessions available in pool. Creating session" , {"kind" : "BurstyPool" }),
268+ ("Creating Session" , {}),
269+ ]
270+ assert got_events == want_events
271+
272+ # Check for the statues.
273+ codes = StatusCode
274+ want_statuses = [
275+ ("CloudSpanner.CreateSession" , codes .OK , None ),
276+ ("CloudSpanner.Transaction.execute_streaming_sql" , codes .OK , None ),
277+ ("CloudSpanner.Transaction.execute_streaming_sql" , codes .OK , None ),
278+ ("CloudSpanner.Transaction.commit" , codes .OK , None ),
279+ (
280+ "CloudSpanner.Session.run_in_transaction" ,
281+ codes .ERROR ,
282+ "409 Thrown from ClientInterceptor for testing" ,
283+ ),
284+ ("CloudSpanner.Database.run_in_transaction" , codes .OK , None ),
285+ ]
286+ assert got_statuses == want_statuses
287+
288+
134289def _make_credentials ():
135290 from google .auth .credentials import AnonymousCredentials
136291
0 commit comments