Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 15 additions & 6 deletions google/cloud/spanner_v1/batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_merge_Transaction_Options,
AtomicCounter,
)
from google.cloud.spanner_v1._opentelemetry_tracing import trace_call
from google.cloud.spanner_v1 import RequestOptions
Expand Down Expand Up @@ -249,17 +250,25 @@ def commit(
observability_options=observability_options,
metadata=metadata,
), MetricsCapture():
method = functools.partial(
api.commit,
request=request,
metadata=metadata,
)
attempt = AtomicCounter()
nth_request = getattr(database, "_next_nth_request", 0)

def wrapped_method(*args, **kwargs):
    """Call ``api.commit``, attaching freshly generated request-id metadata.

    Rebuilt on every invocation so that each retry driven by
    ``_retry_on_aborted_exception`` sends an incremented attempt number:
    ``attempt`` is an ``AtomicCounter`` and ``attempt.increment()`` bumps
    it once per call. ``nth_request`` was captured once from
    ``database._next_nth_request`` before the retry loop, so the request
    number stays constant across attempts while the attempt number grows.
    """
    # NOTE(review): assumes database.metadata_with_request_id returns the
    # prior ``metadata`` extended with an x-goog-spanner-request-id entry —
    # confirm against its definition in database.py.
    method = functools.partial(
        api.commit,
        request=request,
        metadata=database.metadata_with_request_id(
            nth_request, attempt.increment(), metadata
        ),
    )
    return method(*args, **kwargs)

deadline = time.time() + kwargs.get(
"timeout_secs", DEFAULT_RETRY_TIMEOUT_SECS
)
default_retry_delay = kwargs.get("default_retry_delay", None)
response = _retry_on_aborted_exception(
method,
wrapped_method,
deadline=deadline,
default_retry_delay=default_retry_delay,
)
Expand Down
97 changes: 85 additions & 12 deletions google/cloud/spanner_v1/database.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,7 @@
from google.cloud.spanner_v1._helpers import (
_metadata_with_prefix,
_metadata_with_leader_aware_routing,
_metadata_with_request_id,
)
from google.cloud.spanner_v1.batch import Batch
from google.cloud.spanner_v1.batch import MutationGroups
Expand Down Expand Up @@ -151,6 +152,9 @@ class Database(object):

_spanner_api: SpannerClient = None

__transport_lock = threading.Lock()
__transports_to_channel_id = dict()

def __init__(
self,
database_id,
Expand Down Expand Up @@ -188,6 +192,7 @@ def __init__(
self._instance._client.default_transaction_options
)
self._proto_descriptors = proto_descriptors
self._channel_id = 0 # It'll be created when _spanner_api is created.

if pool is None:
pool = BurstyPool(database_role=database_role)
Expand Down Expand Up @@ -446,8 +451,26 @@ def spanner_api(self):
client_info=client_info,
client_options=client_options,
)

with self.__transport_lock:
transport = self._spanner_api._transport
channel_id = self.__transports_to_channel_id.get(transport, None)
if channel_id is None:
channel_id = len(self.__transports_to_channel_id) + 1
self.__transports_to_channel_id[transport] = channel_id
self._channel_id = channel_id

return self._spanner_api

def metadata_with_request_id(self, nth_request, nth_attempt, prior_metadata=None):
    """Return gRPC metadata extended with a request-id entry for this database.

    Delegates to ``_metadata_with_request_id``, combining this client's id,
    the channel id resolved when ``spanner_api`` was created, the request
    number, and the attempt number.

    :type nth_request: int
    :param nth_request: monotonically increasing request number for this client.

    :type nth_attempt: int
    :param nth_attempt: 1-based attempt number; incremented on retries.

    :type prior_metadata: list
    :param prior_metadata: existing metadata key/value pairs to extend.
        Defaults to a fresh empty list when omitted.

    :rtype: list
    :returns: the metadata produced by ``_metadata_with_request_id``.
    """
    # A mutable default ([]) is shared across all calls and can leak
    # mutations between requests; use the None-sentinel idiom instead.
    if prior_metadata is None:
        prior_metadata = []
    return _metadata_with_request_id(
        self._nth_client_id,
        self._channel_id,
        nth_request,
        nth_attempt,
        prior_metadata,
    )

def __eq__(self, other):
if not isinstance(other, self.__class__):
return NotImplemented
Expand Down Expand Up @@ -490,7 +513,10 @@ def create(self):
database_dialect=self._database_dialect,
proto_descriptors=self._proto_descriptors,
)
future = api.create_database(request=request, metadata=metadata)
future = api.create_database(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return future

def exists(self):
Expand All @@ -506,7 +532,12 @@ def exists(self):
metadata = _metadata_with_prefix(self.name)

try:
api.get_database_ddl(database=self.name, metadata=metadata)
api.get_database_ddl(
database=self.name,
metadata=self.metadata_with_request_id(
self._next_nth_request, 1, metadata
),
)
except NotFound:
return False
return True
Expand All @@ -523,10 +554,16 @@ def reload(self):
"""
api = self._instance._client.database_admin_api
metadata = _metadata_with_prefix(self.name)
response = api.get_database_ddl(database=self.name, metadata=metadata)
response = api.get_database_ddl(
database=self.name,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
self._ddl_statements = tuple(response.statements)
self._proto_descriptors = response.proto_descriptors
response = api.get_database(name=self.name, metadata=metadata)
response = api.get_database(
name=self.name,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
self._state = DatabasePB.State(response.state)
self._create_time = response.create_time
self._restore_info = response.restore_info
Expand Down Expand Up @@ -571,7 +608,10 @@ def update_ddl(self, ddl_statements, operation_id="", proto_descriptors=None):
proto_descriptors=proto_descriptors,
)

future = api.update_database_ddl(request=request, metadata=metadata)
future = api.update_database_ddl(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return future

def update(self, fields):
Expand Down Expand Up @@ -609,7 +649,9 @@ def update(self, fields):
metadata = _metadata_with_prefix(self.name)

future = api.update_database(
database=database_pb, update_mask=field_mask, metadata=metadata
database=database_pb,
update_mask=field_mask,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)

return future
Expand All @@ -622,7 +664,10 @@ def drop(self):
"""
api = self._instance._client.database_admin_api
metadata = _metadata_with_prefix(self.name)
api.drop_database(database=self.name, metadata=metadata)
api.drop_database(
database=self.name,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)

def execute_partitioned_dml(
self,
Expand Down Expand Up @@ -711,7 +756,13 @@ def execute_pdml():
with SessionCheckout(self._pool) as session:
add_span_event(span, "Starting BeginTransaction")
txn = api.begin_transaction(
session=session.name, options=txn_options, metadata=metadata
session=session.name,
options=txn_options,
metadata=self.metadata_with_request_id(
self._next_nth_request,
1,
metadata,
),
)

txn_selector = TransactionSelector(id=txn.id)
Expand All @@ -724,6 +775,7 @@ def execute_pdml():
query_options=query_options,
request_options=request_options,
)

method = functools.partial(
api.execute_streaming_sql,
metadata=metadata,
Expand All @@ -736,6 +788,7 @@ def execute_pdml():
metadata=metadata,
transaction_selector=txn_selector,
observability_options=self.observability_options,
request_id_manager=self,
)

result_set = StreamedResultSet(iterator)
Expand All @@ -745,6 +798,16 @@ def execute_pdml():

return _retry_on_aborted(execute_pdml, DEFAULT_RETRY_BACKOFF)()

@property
def _next_nth_request(self):
    """Next request number for this database's owning client.

    Falls back to ``1`` when the owning instance or its client is not
    available (e.g. partially constructed objects in tests).
    """
    owning_client = self._instance._client if self._instance else None
    if not owning_client:
        return 1
    return owning_client._next_nth_request

@property
def _nth_client_id(self):
    """Unique id assigned to the client that owns this database."""
    owning_client = self._instance._client
    return owning_client._nth_client_id

def session(self, labels=None, database_role=None):
"""Factory to create a session for this database.

Expand Down Expand Up @@ -965,7 +1028,8 @@ def restore(self, source):
)
future = api.restore_database(
request=request,
metadata=metadata,
# TODO: Infer the channel_id being used.
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return future

Expand Down Expand Up @@ -1034,7 +1098,10 @@ def list_database_roles(self, page_size=None):
parent=self.name,
page_size=page_size,
)
return api.list_database_roles(request=request, metadata=metadata)
return api.list_database_roles(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)

def table(self, table_id):
"""Factory to create a table object within this database.
Expand Down Expand Up @@ -1118,7 +1185,10 @@ def get_iam_policy(self, policy_version=None):
requested_policy_version=policy_version
),
)
response = api.get_iam_policy(request=request, metadata=metadata)
response = api.get_iam_policy(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return response

def set_iam_policy(self, policy):
Expand All @@ -1140,7 +1210,10 @@ def set_iam_policy(self, policy):
resource=self.name,
policy=policy,
)
response = api.set_iam_policy(request=request, metadata=metadata)
response = api.set_iam_policy(
request=request,
metadata=self.metadata_with_request_id(self._next_nth_request, 1, metadata),
)
return response

@property
Expand Down
8 changes: 6 additions & 2 deletions google/cloud/spanner_v1/pool.py
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,9 @@ def bind(self, database):
)
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(
database._next_nth_request, 1, metadata
),
)

add_span_event(
Expand Down Expand Up @@ -561,7 +563,9 @@ def bind(self, database):
while returned_session_count < self.size:
resp = api.batch_create_sessions(
request=request,
metadata=metadata,
metadata=database.metadata_with_request_id(
database._next_nth_request, 1, metadata
),
)

add_span_event(
Expand Down
39 changes: 30 additions & 9 deletions google/cloud/spanner_v1/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,9 @@ def create(self):
), MetricsCapture():
session_pb = api.create_session(
request=request,
metadata=metadata,
metadata=self._database.metadata_with_request_id(
self._database._next_nth_request, 1, metadata
),
)
self._session_id = session_pb.name.split("/")[-1]

Expand All @@ -195,7 +197,8 @@ def exists(self):
current_span, "Checking if Session exists", {"session.id": self._session_id}
)

api = self._database.spanner_api
database = self._database
api = database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
if self._database._route_to_leader_enabled:
metadata.append(
Expand All @@ -212,7 +215,12 @@ def exists(self):
metadata=metadata,
) as span, MetricsCapture():
try:
api.get_session(name=self.name, metadata=metadata)
api.get_session(
name=self.name,
metadata=database.metadata_with_request_id(
database._next_nth_request, 1, metadata
),
)
if span:
span.set_attribute("session_found", True)
except NotFound:
Expand Down Expand Up @@ -242,8 +250,11 @@ def delete(self):
current_span, "Deleting Session", {"session.id": self._session_id}
)

api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
database = self._database
api = database.spanner_api
metadata = database.metadata_with_request_id(
database._next_nth_request, 1, _metadata_with_prefix(database.name)
)
observability_options = getattr(self._database, "observability_options", None)
with trace_call(
"CloudSpanner.DeleteSession",
Expand All @@ -255,7 +266,10 @@ def delete(self):
observability_options=observability_options,
metadata=metadata,
), MetricsCapture():
api.delete_session(name=self.name, metadata=metadata)
api.delete_session(
name=self.name,
metadata=metadata,
)

def ping(self):
"""Ping the session to keep it alive by executing "SELECT 1".
Expand All @@ -264,10 +278,17 @@ def ping(self):
"""
if self._session_id is None:
raise ValueError("Session ID not set by back-end")
api = self._database.spanner_api
metadata = _metadata_with_prefix(self._database.name)
database = self._database
api = database.spanner_api
request = ExecuteSqlRequest(session=self.name, sql="SELECT 1")
api.execute_sql(request=request, metadata=metadata)
api.execute_sql(
request=request,
metadata=database.metadata_with_request_id(
database._next_nth_request,
1,
_metadata_with_prefix(database.name),
),
)
self._last_use_time = datetime.now()

def snapshot(self, **kw):
Expand Down
Loading
Loading