5757)
5858from google .cloud .spanner_v1 .batch import Batch
5959from google .cloud .spanner_v1 .batch import MutationGroups
60+ from google .cloud .spanner_v1 .database_sessions_manager import DatabaseSessionsManager
6061from google .cloud .spanner_v1 .keyset import KeySet
6162from google .cloud .spanner_v1 .merged_result_set import MergedResultSet
6263from google .cloud .spanner_v1 .pool import BurstyPool
63- from google .cloud .spanner_v1 .pool import SessionCheckout
64- from google .cloud .spanner_v1 .session import Session
64+ from google .cloud .spanner_v1 .session_options import SessionOptions , TransactionType
6565from google .cloud .spanner_v1 .snapshot import _restart_on_unavailable
6666from google .cloud .spanner_v1 .snapshot import Snapshot
6767from google .cloud .spanner_v1 .streamed import StreamedResultSet
@@ -196,9 +196,9 @@ def __init__(
196196
197197 if pool is None :
198198 pool = BurstyPool (database_role = database_role )
199-
200- self ._pool = pool
201199 pool .bind (self )
200+ self ._session_manager = DatabaseSessionsManager (database = self , pool = pool )
201+
202202
203203 @classmethod
204204 def from_pb (cls , database_pb , instance , pool = None ):
@@ -462,6 +462,14 @@ def spanner_api(self):
462462
463463 return self ._spanner_api
464464
465+ @property
466+ def session_options (self ) -> SessionOptions :
467+ """Session options for the database.
468+ :rtype: :class:`~google.cloud.spanner_v1.session_options.SessionOptions`
469+ :returns: the session options
470+ """
471+ return self ._instance ._client .session_options
472+
465473 def metadata_with_request_id (
466474 self , nth_request , nth_attempt , prior_metadata = [], span = None
467475 ):
@@ -759,18 +767,31 @@ def execute_pdml():
759767 "CloudSpanner.Database.execute_partitioned_pdml" ,
760768 observability_options = self .observability_options ,
761769 ) as span , MetricsCapture ():
762- with SessionCheckout (self . _pool ) as session :
770+ with SessionCheckout (self , TransactionType . PARTITIONED ) as session :
763771 add_span_event (span , "Starting BeginTransaction" )
764- txn = api .begin_transaction (
765- session = session .name ,
766- options = txn_options ,
767- metadata = self .metadata_with_request_id (
768- self ._next_nth_request ,
769- 1 ,
770- metadata ,
771- span ,
772- ),
773- )
772+ try :
773+ txn = api .begin_transaction (
774+ session = session .name ,
775+ options = txn_options ,
776+ metadata = self .metadata_with_request_id (
777+ self ._next_nth_request ,
778+ 1 ,
779+ metadata ,
780+ span ,
781+ ),
782+ )
783+ # If partitioned DML is not supported with multiplexed sessions,
784+ # disable multiplexed sessions for partitioned transactions before
785+ # re-raising the error.
786+ except NotImplementedError as exc :
787+ if (
788+ "Transaction type partitioned_dml not supported with multiplexed sessions"
789+ in str (exc )
790+ ):
791+ self .session_options .disable_multiplexed (
792+ self .logger , TransactionType .PARTITIONED
793+ )
794+ raise exc
774795
775796 txn_selector = TransactionSelector (id = txn .id )
776797
@@ -792,6 +813,7 @@ def execute_pdml():
792813 method = method ,
793814 trace_name = "CloudSpanner.ExecuteStreamingSql" ,
794815 request = request ,
816+ session = session ,
795817 metadata = metadata ,
796818 transaction_selector = txn_selector ,
797819 observability_options = self .observability_options ,
@@ -817,23 +839,6 @@ def _nth_client_id(self):
817839 return self ._instance ._client ._nth_client_id
818840 return 0
819841
820- def session (self , labels = None , database_role = None ):
821- """Factory to create a session for this database.
822-
823- :type labels: dict (str -> str) or None
824- :param labels: (Optional) user-assigned labels for the session.
825-
826- :type database_role: str
827- :param database_role: (Optional) user-assigned database_role for the session.
828-
829- :rtype: :class:`~google.cloud.spanner_v1.session.Session`
830- :returns: a session bound to this database.
831- """
832- # If role is specified in param, then that role is used
833- # instead.
834- role = database_role or self ._database_role
835- return Session (self , labels = labels , database_role = role )
836-
837842 def snapshot (self , ** kw ):
838843 """Return an object which wraps a snapshot.
839844
@@ -995,7 +1000,7 @@ def run_in_transaction(self, func, *args, **kw):
9951000 # Check out a session and run the function in a transaction; once
9961001 # done, flip the sanity check bit back.
9971002 try :
998- with SessionCheckout (self . _pool ) as session :
1003+ with SessionCheckout (self ) as session :
9991004 return session .run_in_transaction (func , * args , ** kw )
10001005 finally :
10011006 self ._local .transaction_running = False
@@ -1241,6 +1246,50 @@ def observability_options(self):
12411246 return opts
12421247
12431248
1249+ class SessionCheckout (object ):
1250+ """Context manager for using a session from a database.
1251+ :type database: :class:`~google.cloud.spanner_v1.database.Database`
1252+ :param database: database to use the session from
1253+ """
1254+
1255+ _session = None # Not checked out until '__enter__'.
1256+
1257+ def __init__ (
1258+ self ,
1259+ database , # type: ignore
1260+ transaction_type : TransactionType = TransactionType .READ_WRITE ,
1261+ ):
1262+ # Move import here to avoid circular import
1263+ from google .cloud .spanner_v1 .database import Database
1264+ if not isinstance (database , Database ):
1265+ raise TypeError (
1266+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1267+ class_name = self .__class__ .__name__ ,
1268+ expected_class_name = Database .__name__ ,
1269+ actual_class_name = database .__class__ .__name__ ,
1270+ )
1271+ )
1272+
1273+ if not isinstance (transaction_type , TransactionType ):
1274+ raise TypeError (
1275+ "{class_name} must receive an instance of {expected_class_name}. Received: {actual_class_name}" .format (
1276+ class_name = self .__class__ .__name__ ,
1277+ expected_class_name = TransactionType .__name__ ,
1278+ actual_class_name = transaction_type .__class__ .__name__ ,
1279+ )
1280+ )
1281+
1282+ self ._database = database
1283+ self ._transaction_type = transaction_type
1284+
1285+ def __enter__ (self ):
1286+ session_manager = self ._database ._session_manager
1287+ self ._session = session_manager .get_session (self ._transaction_type )
1288+ return self ._session
1289+
1290+ def __exit__ (self , * ignored ):
1291+ self ._database ._session_manager .put_session (self ._session )
1292+
12441293class BatchCheckout (object ):
12451294 """Context manager for using a batch from a database.
12461295
@@ -1929,3 +1978,4 @@ def _retry_on_aborted(func, retry_config):
19291978 """
19301979 retry = retry_config .with_predicate (if_exception_type (Aborted ))
19311980 return retry (func )
1981+
0 commit comments