diff --git a/gcloud/datastore/client.py b/gcloud/datastore/client.py index f1c7be41da72..b1aa8006fb65 100644 --- a/gcloud/datastore/client.py +++ b/gcloud/datastore/client.py @@ -23,13 +23,15 @@ from gcloud.datastore.key import Key from gcloud.datastore.query import Query from gcloud.datastore.transaction import Transaction +from gcloud.datastore._implicit_environ import _determine_default_dataset_id +from gcloud.datastore._implicit_environ import get_connection class Client(object): """Convenience wrapper for invoking APIs/factories w/ a dataset ID. :type dataset_id: string - :param dataset_id: (required) dataset ID to pass to proxied API methods. + :param dataset_id: (optional) dataset ID to pass to proxied API methods. :type namespace: string :param namespace: (optional) namespace to pass to proxied API methods. @@ -38,12 +40,15 @@ class Client(object): :param connection: (optional) connection to pass to proxied API methods """ - def __init__(self, dataset_id, namespace=None, connection=None): + def __init__(self, dataset_id=None, namespace=None, connection=None): + dataset_id = _determine_default_dataset_id(dataset_id) if dataset_id is None: - raise ValueError('dataset_id required') + raise EnvironmentError('Dataset ID could not be inferred.') self.dataset_id = dataset_id - self.namespace = namespace + if connection is None: + connection = get_connection() self.connection = connection + self.namespace = namespace def get(self, key, missing=None, deferred=None): """Proxy to :func:`gcloud.datastore.api.get`. diff --git a/gcloud/datastore/test_client.py b/gcloud/datastore/test_client.py index 3a5965ca7778..1405a791ccf7 100644 --- a/gcloud/datastore/test_client.py +++ b/gcloud/datastore/test_client.py @@ -18,29 +18,45 @@ class TestClient(unittest2.TestCase): DATASET_ID = 'DATASET' + CONNECTION = object() def _getTargetClass(self): from gcloud.datastore.client import Client return Client - def _makeOne(self, dataset_id=DATASET_ID, namespace=None, connection=None): - return self._getTargetClass()(dataset_id, namespace=namespace, + def _makeOne(self, dataset_id=DATASET_ID, namespace=None, + connection=CONNECTION): + return self._getTargetClass()(dataset_id=dataset_id, + namespace=namespace, connection=connection) - def test_ctor_w_dataset_id_None(self): - self.assertRaises(ValueError, self._makeOne, None) + def test_ctor_w_dataset_id_no_environ(self): + self.assertRaises(EnvironmentError, self._makeOne, None) - def test_ctor_w_dataset_id_no_connection(self): - client = self._makeOne() - self.assertEqual(client.dataset_id, self.DATASET_ID) + def test_ctor_w_implicit_inputs(self): + from gcloud._testing import _Monkey + from gcloud.datastore import client as _MUT + OTHER = 'other' + conn = object() + klass = self._getTargetClass() + with _Monkey(_MUT, + _determine_default_dataset_id=lambda x: x or OTHER, + get_connection=lambda: conn): + client = klass() + self.assertEqual(client.dataset_id, OTHER) + self.assertEqual(client.namespace, None) + self.assertTrue(client.connection is conn) def test_ctor_w_explicit_inputs(self): + OTHER = 'other' + NAMESPACE = 'namespace' conn = object() - namespace = object() - client = self._makeOne(namespace=namespace, connection=conn) - self.assertEqual(client.dataset_id, self.DATASET_ID) + client = self._makeOne(dataset_id=OTHER, + namespace=NAMESPACE, + connection=conn) + self.assertEqual(client.dataset_id, OTHER) + self.assertEqual(client.namespace, NAMESPACE) self.assertTrue(client.connection is conn) - self.assertTrue(client.namespace is namespace) def test_get_defaults(self): from gcloud.datastore import client as MUT @@ -60,7 +76,7 @@ def _get(*args, **kw): self.assertEqual(_called_with[0][0], (key,)) self.assertTrue(_called_with[0][1]['missing'] is None) self.assertTrue(_called_with[0][1]['deferred'] is None) - self.assertTrue(_called_with[0][1]['connection'] is None) + self.assertTrue(_called_with[0][1]['connection'] is self.CONNECTION) self.assertEqual(_called_with[0][1]['dataset_id'], self.DATASET_ID) def test_get_explicit(self): @@ -103,7 +119,7 @@ def _get_multi(*args, **kw): self.assertEqual(_called_with[0][0], ([key],)) self.assertTrue(_called_with[0][1]['missing'] is None) self.assertTrue(_called_with[0][1]['deferred'] is None) - self.assertTrue(_called_with[0][1]['connection'] is None) + self.assertTrue(_called_with[0][1]['connection'] is self.CONNECTION) self.assertEqual(_called_with[0][1]['dataset_id'], self.DATASET_ID) def test_get_multi_explicit(self): @@ -144,7 +160,7 @@ def _put(*args, **kw): client.put(entity) self.assertEqual(_called_with[0][0], (entity,)) - self.assertTrue(_called_with[0][1]['connection'] is None) + self.assertTrue(_called_with[0][1]['connection'] is self.CONNECTION) self.assertEqual(_called_with[0][1]['dataset_id'], self.DATASET_ID) def test_put_w_connection(self): @@ -182,7 +198,7 @@ def _put_multi(*args, **kw): client.put_multi([entity]) self.assertEqual(_called_with[0][0], ([entity],)) - self.assertTrue(_called_with[0][1]['connection'] is None) + self.assertTrue(_called_with[0][1]['connection'] is self.CONNECTION) self.assertEqual(_called_with[0][1]['dataset_id'], self.DATASET_ID) def test_put_multi_w_connection(self): @@ -220,7 +236,7 @@ def _delete(*args, **kw): client.delete(key) self.assertEqual(_called_with[0][0], (key,)) - self.assertTrue(_called_with[0][1]['connection'] is None) + self.assertTrue(_called_with[0][1]['connection'] is self.CONNECTION) self.assertEqual(_called_with[0][1]['dataset_id'], self.DATASET_ID) def test_delete_w_connection(self): @@ -257,7 +273,7 @@ def _delete_multi(*args, **kw): client.delete_multi([key]) self.assertEqual(_called_with[0][0], ([key],)) - self.assertTrue(_called_with[0][1]['connection'] is None) + self.assertTrue(_called_with[0][1]['connection'] is self.CONNECTION) self.assertEqual(_called_with[0][1]['dataset_id'], self.DATASET_ID) def test_delete_multi_w_connection(self): @@ -351,7 +367,8 @@ def test_batch_wo_connection(self): self.assertTrue(isinstance(batch, _Dummy)) self.assertEqual(batch.args, ()) self.assertEqual(batch.kwargs, - {'dataset_id': self.DATASET_ID, 'connection': None}) + {'dataset_id': self.DATASET_ID, + 'connection': self.CONNECTION}) def test_batch_w_connection(self): from gcloud.datastore import client as MUT @@ -378,7 +395,8 @@ def test_transaction_wo_connection(self): self.assertTrue(isinstance(xact, _Dummy)) self.assertEqual(xact.args, ()) self.assertEqual(xact.kwargs, - {'dataset_id': self.DATASET_ID, 'connection': None}) + {'dataset_id': self.DATASET_ID, + 'connection': self.CONNECTION}) def test_transaction_w_connection(self): from gcloud.datastore import client as MUT