Skip to content

Commit 0d8a7da

Browse files
committed
feat: Multiplexed sessions - Update Connection to use multiplexed sessions, add unit tests.
Signed-off-by: Taylor Curran <[email protected]>
1 parent 7fcd202 commit 0d8a7da

File tree

3 files changed

+91
-53
lines changed

3 files changed

+91
-53
lines changed

google/cloud/spanner_dbapi/connection.py

Lines changed: 14 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
from google.cloud.spanner_dbapi.transaction_helper import TransactionRetryHelper
2929
from google.cloud.spanner_dbapi.cursor import Cursor
3030
from google.cloud.spanner_v1 import RequestOptions, TransactionOptions
31+
from google.cloud.spanner_v1.session_options import TransactionType
3132
from google.cloud.spanner_v1.snapshot import Snapshot
3233

3334
from google.cloud.spanner_dbapi.exceptions import (
@@ -356,8 +357,16 @@ def _session_checkout(self):
356357
"""
357358
if self.database is None:
358359
raise ValueError("Database needs to be passed for this operation")
360+
359361
if not self._session:
360-
self._session = self.database._pool.get()
362+
transaction_type = (
363+
TransactionType.READ_ONLY
364+
if self.read_only
365+
else TransactionType.READ_WRITE
366+
)
367+
self._session = self.database._sessions_manager.get_session(
368+
transaction_type
369+
)
361370

362371
return self._session
363372

@@ -368,9 +377,11 @@ def _release_session(self):
368377
"""
369378
if self._session is None:
370379
return
380+
371381
if self.database is None:
372382
raise ValueError("Database needs to be passed for this operation")
373-
self.database._pool.put(self._session)
383+
384+
self.database._sessions_manager.put_session(self._session)
374385
self._session = None
375386

376387
def transaction_checkout(self):
@@ -432,7 +443,7 @@ def close(self):
432443
self._transaction.rollback()
433444

434445
if self._own_pool and self.database:
435-
self.database._pool.clear()
446+
self.database._sessions_manager._pool.clear()
436447

437448
self.is_closed = True
438449

tests/_builders.py

Lines changed: 46 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -11,8 +11,16 @@
1111
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
14-
14+
from logging import Logger
1515
from mock import create_autospec
16+
from typing import Mapping
17+
18+
from google.cloud.spanner_dbapi import Connection
19+
from google.cloud.spanner_v1 import SpannerClient
20+
from google.cloud.spanner_v1.client import Client
21+
from google.cloud.spanner_v1.database import Database
22+
from google.cloud.spanner_v1.instance import Instance
23+
from google.cloud.spanner_v1.session import Session
1624

1725
# Default values used to populate required or expected attributes.
1826
# Tests should not depend on them: if a test requires a specific
@@ -22,67 +30,78 @@
2230
_DATABASE_ID = "default-database-id"
2331

2432

25-
def build_logger():
26-
"""Builds and returns a logger for testing."""
27-
from logging import Logger
28-
29-
return create_autospec(Logger, instance=True)
30-
31-
32-
# Client objects
33-
# --------------
34-
35-
36-
def build_client(**kwargs):
33+
def build_client(**kwargs: Mapping) -> Client:
3734
"""Builds and returns a client for testing using the given arguments.
3835
If a required argument is not provided, a default value will be used."""
39-
from google.cloud.spanner_v1 import Client
4036

4137
if "project" not in kwargs:
4238
kwargs["project"] = _PROJECT_ID
4339

4440
return Client(**kwargs)
4541

4642

47-
def build_database(**kwargs):
43+
def build_connection(**kwargs: Mapping) -> Connection:
44+
"""Builds and returns a connection for testing using the given arguments.
45+
If a required argument is not provided, a default value will be used."""
46+
47+
if "instance" not in kwargs:
48+
kwargs["instance"] = build_instance()
49+
50+
if "database" not in kwargs:
51+
kwargs["database"] = build_database(instance=kwargs["instance"])
52+
53+
return Connection(**kwargs)
54+
55+
56+
def build_database(**kwargs: Mapping) -> Database:
4857
"""Builds and returns a database for testing using the given arguments.
49-
If a required argument is not provided, a default value will be used.."""
50-
from google.cloud.spanner_v1.database import Database
58+
If a required argument is not provided, a default value will be used."""
5159

5260
if "database_id" not in kwargs:
5361
kwargs["database_id"] = _DATABASE_ID
5462

5563
if "logger" not in kwargs:
5664
kwargs["logger"] = build_logger()
5765

58-
if "instance" not in kwargs or isinstance(kwargs["instance"], dict):
59-
instance_args = kwargs.pop("instance", {})
60-
kwargs["instance"] = build_instance(**instance_args)
66+
if "instance" not in kwargs:
67+
kwargs["instance"] = build_instance()
6168

6269
database = Database(**kwargs)
6370
database._spanner_api = build_spanner_api()
6471

6572
return database
6673

6774

68-
def build_instance(**kwargs):
75+
def build_instance(**kwargs: Mapping) -> Instance:
6976
"""Builds and returns an instance for testing using the given arguments.
7077
If a required argument is not provided, a default value will be used."""
71-
from google.cloud.spanner_v1.instance import Instance
7278

7379
if "instance_id" not in kwargs:
7480
kwargs["instance_id"] = _INSTANCE_ID
7581

76-
if "client" not in kwargs or isinstance(kwargs["client"], dict):
77-
client_args = kwargs.pop("client", {})
78-
kwargs["client"] = build_client(**client_args)
82+
if "client" not in kwargs:
83+
kwargs["client"] = build_client()
7984

8085
return Instance(**kwargs)
8186

8287

83-
def build_spanner_api():
88+
def build_logger() -> Logger:
89+
"""Builds and returns a logger for testing."""
90+
return create_autospec(Logger, instance=True)
91+
92+
93+
def build_session(**kwargs: Mapping) -> Session:
94+
"""Builds and returns a session for testing using the given arguments.
95+
If a required argument is not provided, a default value will be used."""
96+
97+
if "database" not in kwargs:
98+
kwargs["database"] = build_database()
99+
100+
return Session(**kwargs)
101+
102+
103+
def build_spanner_api() -> SpannerClient:
84104
"""Builds and returns a mock Spanner Client API for testing using the given arguments.
85105
Commonly used methods are mocked to return default values."""
86-
from google.cloud.spanner_v1 import SpannerClient
87106

88107
return create_autospec(SpannerClient, instance=True)

tests/unit/spanner_dbapi/test_connection.py

Lines changed: 31 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
ClientSideStatementType,
3838
AutocommitDmlMode,
3939
)
40+
from google.cloud.spanner_v1.session_options import TransactionType
41+
from tests._builders import build_connection, build_session
4042

4143
PROJECT = "test-project"
4244
INSTANCE = "test-instance"
@@ -151,42 +153,48 @@ def test_read_only_connection(self):
151153
connection.read_only = False
152154
self.assertFalse(connection.read_only)
153155

154-
@staticmethod
155-
def _make_pool():
156-
from google.cloud.spanner_v1.pool import AbstractSessionPool
156+
def test__session_checkout_read_only(self):
157+
connection = build_connection(read_only=True)
158+
database = connection._database
159+
sessions_manager = database._sessions_manager
157160

158-
return mock.create_autospec(AbstractSessionPool)
161+
expected_session = build_session(database=database)
162+
sessions_manager.get_session = mock.MagicMock(return_value=expected_session)
159163

160-
@mock.patch("google.cloud.spanner_v1.database.Database")
161-
def test__session_checkout(self, mock_database):
162-
pool = self._make_pool()
163-
mock_database._pool = pool
164-
connection = Connection(INSTANCE, mock_database)
164+
actual_session = connection._session_checkout()
165+
166+
self.assertEqual(actual_session, expected_session)
167+
sessions_manager.get_session.assert_called_once_with(TransactionType.READ_ONLY)
168+
169+
def test__session_checkout_read_write(self):
170+
connection = build_connection(read_only=False)
171+
database = connection._database
172+
sessions_manager = database._sessions_manager
173+
174+
expected_session = build_session(database=database)
175+
sessions_manager.get_session = mock.MagicMock(return_value=expected_session)
165176

166-
connection._session_checkout()
167-
pool.get.assert_called_once_with()
168-
self.assertEqual(connection._session, pool.get.return_value)
177+
actual_session = connection._session_checkout()
169178

170-
connection._session = "db_session"
171-
connection._session_checkout()
172-
self.assertEqual(connection._session, "db_session")
179+
self.assertEqual(actual_session, expected_session)
180+
sessions_manager.get_session.assert_called_once_with(TransactionType.READ_WRITE)
173181

174182
def test_session_checkout_database_error(self):
175183
connection = Connection(INSTANCE)
176184

177185
with pytest.raises(ValueError):
178186
connection._session_checkout()
179187

180-
@mock.patch("google.cloud.spanner_v1.database.Database")
181-
def test__release_session(self, mock_database):
182-
pool = self._make_pool()
183-
mock_database._pool = pool
184-
connection = Connection(INSTANCE, mock_database)
185-
connection._session = "session"
188+
def test__release_session(self):
189+
connection = build_connection()
190+
sessions_manager = connection._database._sessions_manager
191+
192+
session = connection._session = build_session(database=connection._database)
193+
put_session = sessions_manager.put_session = mock.MagicMock()
186194

187195
connection._release_session()
188-
pool.put.assert_called_once_with("session")
189-
self.assertIsNone(connection._session)
196+
197+
put_session.assert_called_once_with(session)
190198

191199
def test_release_session_database_error(self):
192200
connection = Connection(INSTANCE)

0 commit comments

Comments
 (0)